Code quality (#114)

* to color_eyre::Result 1st part

* tests and seconds batch

* third batch

* rln fixes + multiplier

* rln-wasm, assert rln, multiplier

* io to color_eyre

* fmt + clippy

* fix lint

* temporary fix of `ark-circom`

* fix ci after merge

* fmt

* fix rln tests

* minor

* fix tests

* imports

* requested change

* report + commented line + requested change

* requested changes

* fix build

* lint fixes

* better comments

---------

Co-authored-by: tyshkor <tyshko1@gmail.com>
This commit is contained in:
tyshko-rostyslav 2023-02-27 07:16:16 +01:00 committed by GitHub
parent 62018b4eba
commit 55b00fd653
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 606 additions and 535 deletions

View File

@ -31,12 +31,12 @@ impl<'a> From<&Buffer> for &'a [u8] {
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn new_circuit(ctx: *mut *mut Multiplier) -> bool {
println!("multiplier ffi: new");
let mul = Multiplier::new();
unsafe { *ctx = Box::into_raw(Box::new(mul)) };
true
if let Ok(mul) = Multiplier::new() {
unsafe { *ctx = Box::into_raw(Box::new(mul)) };
true
} else {
false
}
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]

View File

@ -1,6 +1,6 @@
use ark_circom::{CircomBuilder, CircomConfig};
use ark_std::rand::thread_rng;
use color_eyre::Result;
use color_eyre::{Report, Result};
use ark_bn254::Bn254;
use ark_groth16::{
@ -25,17 +25,18 @@ fn groth16_proof_example() -> Result<()> {
let circom = builder.build()?;
let inputs = circom.get_public_inputs().unwrap();
let inputs = circom
.get_public_inputs()
.ok_or(Report::msg("no public inputs"))?;
let proof = prove(circom, &params, &mut rng)?;
let pvk = prepare_verifying_key(&params.vk);
let verified = verify_proof(&pvk, &proof, &inputs)?;
assert!(verified);
Ok(())
match verify_proof(&pvk, &proof, &inputs) {
Ok(_) => Ok(()),
Err(_) => Err(Report::msg("not verified")),
}
}
fn main() {

View File

@ -7,9 +7,8 @@ use ark_groth16::{
Proof, ProvingKey,
};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
// , SerializationError};
use std::io::{self, Read, Write};
use color_eyre::{Report, Result};
use std::io::{Read, Write};
pub struct Multiplier {
circom: CircomCircuit<Bn254>,
@ -18,12 +17,11 @@ pub struct Multiplier {
impl Multiplier {
// TODO Break this apart here
pub fn new() -> Multiplier {
pub fn new() -> Result<Multiplier> {
let cfg = CircomConfig::<Bn254>::new(
"./resources/circom2_multiplier2.wasm",
"./resources/circom2_multiplier2.r1cs",
)
.unwrap();
)?;
let mut builder = CircomBuilder::new(cfg);
builder.push_input("a", 3);
@ -34,40 +32,41 @@ impl Multiplier {
let mut rng = thread_rng();
let params = generate_random_parameters::<Bn254, _, _>(circom, &mut rng).unwrap();
let params = generate_random_parameters::<Bn254, _, _>(circom, &mut rng)?;
let circom = builder.build().unwrap();
let circom = builder.build()?;
//let inputs = circom.get_public_inputs().unwrap();
Multiplier { circom, params }
Ok(Multiplier { circom, params })
}
// TODO Input Read
pub fn prove<W: Write>(&self, result_data: W) -> io::Result<()> {
pub fn prove<W: Write>(&self, result_data: W) -> Result<()> {
let mut rng = thread_rng();
// XXX: There's probably a better way to do this
let circom = self.circom.clone();
let params = self.params.clone();
let proof = prove(circom, &params, &mut rng).unwrap();
let proof = prove(circom, &params, &mut rng)?;
// XXX: Unclear if this is different from other serialization(s)
proof.serialize(result_data).unwrap();
proof.serialize(result_data)?;
Ok(())
}
pub fn verify<R: Read>(&self, input_data: R) -> io::Result<bool> {
let proof = Proof::deserialize(input_data).unwrap();
pub fn verify<R: Read>(&self, input_data: R) -> Result<bool> {
let proof = Proof::deserialize(input_data)?;
let pvk = prepare_verifying_key(&self.params.vk);
// XXX Part of input data?
let inputs = self.circom.get_public_inputs().unwrap();
let inputs = self
.circom
.get_public_inputs()
.ok_or(Report::msg("no public inputs"))?;
let verified = verify_proof(&pvk, &proof, &inputs).unwrap();
let verified = verify_proof(&pvk, &proof, &inputs)?;
Ok(verified)
}
@ -75,6 +74,6 @@ impl Multiplier {
impl Default for Multiplier {
fn default() -> Self {
Self::new()
Self::new().unwrap()
}
}

View File

@ -4,8 +4,7 @@ mod tests {
#[test]
fn multiplier_proof() {
let mul = Multiplier::new();
//let inputs = mul.circom.get_public_inputs().unwrap();
let mul = Multiplier::new().unwrap();
let mut output_data: Vec<u8> = Vec::new();
let _ = mul.prove(&mut output_data);

View File

@ -20,6 +20,7 @@ wasm-bindgen = "0.2.63"
serde-wasm-bindgen = "0.4"
js-sys = "0.3.59"
serde_json = "1.0.85"
anyhow = "1.0.69"
# The `console_error_panic_hook` crate provides better debugging of panics by
# logging them with `console.error`. This is great for development, but requires

View File

@ -22,21 +22,30 @@ pub struct RLNWrapper {
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = newRLN)]
pub fn wasm_new(tree_height: usize, zkey: Uint8Array, vk: Uint8Array) -> *mut RLNWrapper {
let instance = RLN::new_with_params(tree_height, zkey.to_vec(), vk.to_vec());
pub fn wasm_new(
tree_height: usize,
zkey: Uint8Array,
vk: Uint8Array,
) -> Result<*mut RLNWrapper, String> {
let instance = RLN::new_with_params(tree_height, zkey.to_vec(), vk.to_vec())
.map_err(|err| format!("{:#?}", err))?;
let wrapper = RLNWrapper { instance };
Box::into_raw(Box::new(wrapper))
Ok(Box::into_raw(Box::new(wrapper)))
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = getSerializedRLNWitness)]
pub fn wasm_get_serialized_rln_witness(ctx: *mut RLNWrapper, input: Uint8Array) -> Uint8Array {
pub fn wasm_get_serialized_rln_witness(
ctx: *mut RLNWrapper,
input: Uint8Array,
) -> Result<Uint8Array, String> {
let wrapper = unsafe { &mut *ctx };
let rln_witness = wrapper
.instance
.get_serialized_rln_witness(&input.to_vec()[..]);
.get_serialized_rln_witness(&input.to_vec()[..])
.map_err(|err| format!("{:#?}", err))?;
Uint8Array::from(&rln_witness[..])
Ok(Uint8Array::from(&rln_witness[..]))
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
@ -86,16 +95,18 @@ pub fn wasm_init_tree_with_leaves(ctx: *mut RLNWrapper, input: Uint8Array) -> Re
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = RLNWitnessToJson)]
pub fn rln_witness_to_json(ctx: *mut RLNWrapper, serialized_witness: Uint8Array) -> Object {
pub fn rln_witness_to_json(
ctx: *mut RLNWrapper,
serialized_witness: Uint8Array,
) -> Result<Object, String> {
let wrapper = unsafe { &mut *ctx };
let inputs = wrapper
.instance
.get_rln_witness_json(&serialized_witness.to_vec()[..])
.unwrap();
.map_err(|err| err.to_string())?;
let js_value = serde_wasm_bindgen::to_value(&inputs).unwrap();
let obj = Object::from_entries(&js_value);
obj.unwrap()
let js_value = serde_wasm_bindgen::to_value(&inputs).map_err(|err| err.to_string())?;
Object::from_entries(&js_value).map_err(|err| format!("{:#?}", err))
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
@ -107,17 +118,18 @@ pub fn generate_rln_proof_with_witness(
) -> Result<Uint8Array, String> {
let wrapper = unsafe { &mut *ctx };
let witness_vec: Vec<BigInt> = calculated_witness
.iter()
.map(|v| {
let mut witness_vec: Vec<BigInt> = vec![];
for v in calculated_witness {
witness_vec.push(
v.to_string(10)
.unwrap()
.map_err(|err| format!("{:#?}", err))?
.as_string()
.unwrap()
.ok_or("not a string error")?
.parse::<BigInt>()
.unwrap()
})
.collect();
.map_err(|err| format!("{:#?}", err))?,
);
}
let mut output_data: Vec<u8> = Vec::new();

View File

@ -8,10 +8,11 @@ use ark_circom::read_zkey;
use ark_groth16::{ProvingKey, VerifyingKey};
use ark_relations::r1cs::ConstraintMatrices;
use cfg_if::cfg_if;
use color_eyre::{Report, Result};
use num_bigint::BigUint;
use serde_json::Value;
use std::fs::File;
use std::io::{Cursor, Error, ErrorKind, Result};
use std::io::Cursor;
use std::path::Path;
use std::str::FromStr;
@ -57,7 +58,7 @@ pub fn zkey_from_raw(zkey_data: &Vec<u8>) -> Result<(ProvingKey<Curve>, Constrai
let proving_key_and_matrices = read_zkey(&mut c)?;
Ok(proving_key_and_matrices)
} else {
Err(Error::new(ErrorKind::NotFound, "No proving key found!"))
Err(Report::msg("No proving key found!"))
}
}
@ -71,7 +72,7 @@ pub fn zkey_from_folder(
let proving_key_and_matrices = read_zkey(&mut file)?;
Ok(proving_key_and_matrices)
} else {
Err(Error::new(ErrorKind::NotFound, "No proving key found!"))
Err(Report::msg("No proving key found!"))
}
}
@ -80,17 +81,14 @@ pub fn vk_from_raw(vk_data: &Vec<u8>, zkey_data: &Vec<u8>) -> Result<VerifyingKe
let verifying_key: VerifyingKey<Curve>;
if !vk_data.is_empty() {
verifying_key = vk_from_vector(vk_data);
verifying_key = vk_from_vector(vk_data)?;
Ok(verifying_key)
} else if !zkey_data.is_empty() {
let (proving_key, _matrices) = zkey_from_raw(zkey_data)?;
verifying_key = proving_key.vk;
Ok(verifying_key)
} else {
Err(Error::new(
ErrorKind::NotFound,
"No proving/verification key found!",
))
Err(Report::msg("No proving/verification key found!"))
}
}
@ -102,17 +100,13 @@ pub fn vk_from_folder(resources_folder: &str) -> Result<VerifyingKey<Curve>> {
let verifying_key: VerifyingKey<Curve>;
if Path::new(&vk_path).exists() {
verifying_key = vk_from_json(&vk_path);
Ok(verifying_key)
vk_from_json(&vk_path)
} else if Path::new(&zkey_path).exists() {
let (proving_key, _matrices) = zkey_from_folder(resources_folder)?;
verifying_key = proving_key.vk;
Ok(verifying_key)
} else {
Err(Error::new(
ErrorKind::NotFound,
"No proving/verification key found!",
))
Err(Report::msg("No proving/verification key found!"))
}
}
@ -121,129 +115,146 @@ static WITNESS_CALCULATOR: OnceCell<Mutex<WitnessCalculator>> = OnceCell::new();
// Initializes the witness calculator using a bytes vector
#[cfg(not(target_arch = "wasm32"))]
pub fn circom_from_raw(wasm_buffer: Vec<u8>) -> &'static Mutex<WitnessCalculator> {
WITNESS_CALCULATOR.get_or_init(|| {
pub fn circom_from_raw(wasm_buffer: Vec<u8>) -> Result<&'static Mutex<WitnessCalculator>> {
WITNESS_CALCULATOR.get_or_try_init(|| {
let store = Store::default();
let module = Module::new(&store, wasm_buffer).unwrap();
let result =
WitnessCalculator::from_module(module).expect("Failed to create witness calculator");
Mutex::new(result)
let module = Module::new(&store, wasm_buffer)?;
let result = WitnessCalculator::from_module(module)?;
Ok::<Mutex<WitnessCalculator>, Report>(Mutex::new(result))
})
}
// Initializes the witness calculator
#[cfg(not(target_arch = "wasm32"))]
pub fn circom_from_folder(resources_folder: &str) -> &'static Mutex<WitnessCalculator> {
pub fn circom_from_folder(resources_folder: &str) -> Result<&'static Mutex<WitnessCalculator>> {
// We read the wasm file
let wasm_path = format!("{resources_folder}{WASM_FILENAME}");
let wasm_buffer = std::fs::read(wasm_path).unwrap();
let wasm_buffer = std::fs::read(wasm_path)?;
circom_from_raw(wasm_buffer)
}
// The following function implementations are taken/adapted from https://github.com/gakonst/ark-circom/blob/1732e15d6313fe176b0b1abb858ac9e095d0dbd7/src/zkey.rs
// Utilities to convert a json verification key in a groth16::VerificationKey
fn fq_from_str(s: &str) -> Fq {
Fq::try_from(BigUint::from_str(s).unwrap()).unwrap()
fn fq_from_str(s: &str) -> Result<Fq> {
Ok(Fq::try_from(BigUint::from_str(s)?)?)
}
// Extracts the element in G1 corresponding to its JSON serialization
fn json_to_g1(json: &Value, key: &str) -> G1Affine {
fn json_to_g1(json: &Value, key: &str) -> Result<G1Affine> {
let els: Vec<String> = json
.get(key)
.unwrap()
.ok_or(Report::msg("no json value"))?
.as_array()
.unwrap()
.ok_or(Report::msg("value not an array"))?
.iter()
.map(|i| i.as_str().unwrap().to_string())
.collect();
G1Affine::from(G1Projective::new(
fq_from_str(&els[0]),
fq_from_str(&els[1]),
fq_from_str(&els[2]),
))
.map(|i| i.as_str().ok_or(Report::msg("element is not a string")))
.map(|x| x.map(|v| v.to_owned()))
.collect::<Result<Vec<String>>>()?;
Ok(G1Affine::from(G1Projective::new(
fq_from_str(&els[0])?,
fq_from_str(&els[1])?,
fq_from_str(&els[2])?,
)))
}
// Extracts the vector of G1 elements corresponding to its JSON serialization
fn json_to_g1_vec(json: &Value, key: &str) -> Vec<G1Affine> {
fn json_to_g1_vec(json: &Value, key: &str) -> Result<Vec<G1Affine>> {
let els: Vec<Vec<String>> = json
.get(key)
.unwrap()
.ok_or(Report::msg("no json value"))?
.as_array()
.unwrap()
.ok_or(Report::msg("value not an array"))?
.iter()
.map(|i| {
i.as_array()
.unwrap()
.iter()
.map(|x| x.as_str().unwrap().to_string())
.collect::<Vec<String>>()
.ok_or(Report::msg("element is not an array"))
.and_then(|array| {
array
.iter()
.map(|x| x.as_str().ok_or(Report::msg("element is not a string")))
.map(|x| x.map(|v| v.to_owned()))
.collect::<Result<Vec<String>>>()
})
})
.collect();
.collect::<Result<Vec<Vec<String>>>>()?;
els.iter()
.map(|coords| {
G1Affine::from(G1Projective::new(
fq_from_str(&coords[0]),
fq_from_str(&coords[1]),
fq_from_str(&coords[2]),
))
})
.collect()
let mut res = vec![];
for coords in els {
res.push(G1Affine::from(G1Projective::new(
fq_from_str(&coords[0])?,
fq_from_str(&coords[1])?,
fq_from_str(&coords[2])?,
)))
}
Ok(res)
}
// Extracts the element in G2 corresponding to its JSON serialization
fn json_to_g2(json: &Value, key: &str) -> G2Affine {
fn json_to_g2(json: &Value, key: &str) -> Result<G2Affine> {
let els: Vec<Vec<String>> = json
.get(key)
.unwrap()
.ok_or(Report::msg("no json value"))?
.as_array()
.unwrap()
.ok_or(Report::msg("value not an array"))?
.iter()
.map(|i| {
i.as_array()
.unwrap()
.iter()
.map(|x| x.as_str().unwrap().to_string())
.collect::<Vec<String>>()
.ok_or(Report::msg("element is not an array"))
.and_then(|array| {
array
.iter()
.map(|x| x.as_str().ok_or(Report::msg("element is not a string")))
.map(|x| x.map(|v| v.to_owned()))
.collect::<Result<Vec<String>>>()
})
})
.collect();
.collect::<Result<Vec<Vec<String>>>>()?;
let x = Fq2::new(fq_from_str(&els[0][0]), fq_from_str(&els[0][1]));
let y = Fq2::new(fq_from_str(&els[1][0]), fq_from_str(&els[1][1]));
let z = Fq2::new(fq_from_str(&els[2][0]), fq_from_str(&els[2][1]));
G2Affine::from(G2Projective::new(x, y, z))
let x = Fq2::new(fq_from_str(&els[0][0])?, fq_from_str(&els[0][1])?);
let y = Fq2::new(fq_from_str(&els[1][0])?, fq_from_str(&els[1][1])?);
let z = Fq2::new(fq_from_str(&els[2][0])?, fq_from_str(&els[2][1])?);
Ok(G2Affine::from(G2Projective::new(x, y, z)))
}
// Converts JSON to a VerifyingKey
fn to_verifying_key(json: serde_json::Value) -> VerifyingKey<Curve> {
VerifyingKey {
alpha_g1: json_to_g1(&json, "vk_alpha_1"),
beta_g2: json_to_g2(&json, "vk_beta_2"),
gamma_g2: json_to_g2(&json, "vk_gamma_2"),
delta_g2: json_to_g2(&json, "vk_delta_2"),
gamma_abc_g1: json_to_g1_vec(&json, "IC"),
}
fn to_verifying_key(json: serde_json::Value) -> Result<VerifyingKey<Curve>> {
Ok(VerifyingKey {
alpha_g1: json_to_g1(&json, "vk_alpha_1")?,
beta_g2: json_to_g2(&json, "vk_beta_2")?,
gamma_g2: json_to_g2(&json, "vk_gamma_2")?,
delta_g2: json_to_g2(&json, "vk_delta_2")?,
gamma_abc_g1: json_to_g1_vec(&json, "IC")?,
})
}
// Computes the verification key from its JSON serialization
fn vk_from_json(vk_path: &str) -> VerifyingKey<Curve> {
let json = std::fs::read_to_string(vk_path).unwrap();
let json: Value = serde_json::from_str(&json).unwrap();
fn vk_from_json(vk_path: &str) -> Result<VerifyingKey<Curve>> {
let json = std::fs::read_to_string(vk_path)?;
let json: Value = serde_json::from_str(&json)?;
to_verifying_key(json)
}
// Computes the verification key from a bytes vector containing its JSON serialization
fn vk_from_vector(vk: &[u8]) -> VerifyingKey<Curve> {
let json = String::from_utf8(vk.to_vec()).expect("Found invalid UTF-8");
let json: Value = serde_json::from_str(&json).unwrap();
fn vk_from_vector(vk: &[u8]) -> Result<VerifyingKey<Curve>> {
let json = String::from_utf8(vk.to_vec())?;
let json: Value = serde_json::from_str(&json)?;
to_verifying_key(json)
}
// Checks verification key to be correct with respect to proving key
pub fn check_vk_from_zkey(resources_folder: &str, verifying_key: VerifyingKey<Curve>) {
let (proving_key, _matrices) = zkey_from_folder(resources_folder).unwrap();
assert_eq!(proving_key.vk, verifying_key);
pub fn check_vk_from_zkey(
resources_folder: &str,
verifying_key: VerifyingKey<Curve>,
) -> Result<()> {
let (proving_key, _matrices) = zkey_from_folder(resources_folder)?;
if proving_key.vk == verifying_key {
Ok(())
} else {
Err(Report::msg("verifying_keys are not equal"))
}
}

View File

@ -171,9 +171,12 @@ impl<'a> From<&Buffer> for &'a [u8] {
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn new(tree_height: usize, input_buffer: *const Buffer, ctx: *mut *mut RLN) -> bool {
let rln = RLN::new(tree_height, input_buffer.process());
unsafe { *ctx = Box::into_raw(Box::new(rln)) };
true
if let Ok(rln) = RLN::new(tree_height, input_buffer.process()) {
unsafe { *ctx = Box::into_raw(Box::new(rln)) };
true
} else {
false
}
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
@ -185,14 +188,17 @@ pub extern "C" fn new_with_params(
vk_buffer: *const Buffer,
ctx: *mut *mut RLN,
) -> bool {
let rln = RLN::new_with_params(
if let Ok(rln) = RLN::new_with_params(
tree_height,
circom_buffer.process().to_vec(),
zkey_buffer.process().to_vec(),
vk_buffer.process().to_vec(),
);
unsafe { *ctx = Box::into_raw(Box::new(rln)) };
true
) {
unsafe { *ctx = Box::into_raw(Box::new(rln)) };
true
} else {
false
}
}
////////////////////////////////////////////////////////

View File

@ -8,7 +8,7 @@ use ark_groth16::{
use ark_relations::r1cs::ConstraintMatrices;
use ark_relations::r1cs::SynthesisError;
use ark_std::{rand::thread_rng, UniformRand};
use color_eyre::Result;
use color_eyre::{Report, Result};
use num_bigint::BigInt;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
@ -91,29 +91,29 @@ pub fn deserialize_identity_tuple(serialized: Vec<u8>) -> (Fr, Fr, Fr, Fr) {
)
}
pub fn serialize_witness(rln_witness: &RLNWitnessInput) -> Vec<u8> {
pub fn serialize_witness(rln_witness: &RLNWitnessInput) -> Result<Vec<u8>> {
let mut serialized: Vec<u8> = Vec::new();
serialized.append(&mut fr_to_bytes_le(&rln_witness.identity_secret));
serialized.append(&mut vec_fr_to_bytes_le(&rln_witness.path_elements));
serialized.append(&mut vec_u8_to_bytes_le(&rln_witness.identity_path_index));
serialized.append(&mut vec_fr_to_bytes_le(&rln_witness.path_elements)?);
serialized.append(&mut vec_u8_to_bytes_le(&rln_witness.identity_path_index)?);
serialized.append(&mut fr_to_bytes_le(&rln_witness.x));
serialized.append(&mut fr_to_bytes_le(&rln_witness.epoch));
serialized.append(&mut fr_to_bytes_le(&rln_witness.rln_identifier));
serialized
Ok(serialized)
}
pub fn deserialize_witness(serialized: &[u8]) -> (RLNWitnessInput, usize) {
pub fn deserialize_witness(serialized: &[u8]) -> Result<(RLNWitnessInput, usize)> {
let mut all_read: usize = 0;
let (identity_secret, read) = bytes_le_to_fr(&serialized[all_read..]);
all_read += read;
let (path_elements, read) = bytes_le_to_vec_fr(&serialized[all_read..]);
let (path_elements, read) = bytes_le_to_vec_fr(&serialized[all_read..])?;
all_read += read;
let (identity_path_index, read) = bytes_le_to_vec_u8(&serialized[all_read..]);
let (identity_path_index, read) = bytes_le_to_vec_u8(&serialized[all_read..])?;
all_read += read;
let (x, read) = bytes_le_to_fr(&serialized[all_read..]);
@ -126,9 +126,11 @@ pub fn deserialize_witness(serialized: &[u8]) -> (RLNWitnessInput, usize) {
all_read += read;
// TODO: check rln_identifier against public::RLN_IDENTIFIER
assert_eq!(serialized.len(), all_read);
if serialized.len() != all_read {
return Err(Report::msg("serialized length is not equal to all_read"));
}
(
Ok((
RLNWitnessInput {
identity_secret,
path_elements,
@ -138,7 +140,7 @@ pub fn deserialize_witness(serialized: &[u8]) -> (RLNWitnessInput, usize) {
rln_identifier,
},
all_read,
)
))
}
// This function deserializes input for kilic's rln generate_proof public API
@ -148,19 +150,19 @@ pub fn deserialize_witness(serialized: &[u8]) -> (RLNWitnessInput, usize) {
pub fn proof_inputs_to_rln_witness(
tree: &mut PoseidonTree,
serialized: &[u8],
) -> (RLNWitnessInput, usize) {
) -> Result<(RLNWitnessInput, usize)> {
let mut all_read: usize = 0;
let (identity_secret, read) = bytes_le_to_fr(&serialized[all_read..]);
all_read += read;
let id_index = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap());
let id_index = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into()?);
all_read += 8;
let (epoch, read) = bytes_le_to_fr(&serialized[all_read..]);
all_read += read;
let signal_len = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap());
let signal_len = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into()?);
all_read += 8;
let signal: Vec<u8> = serialized[all_read..all_read + (signal_len as usize)].to_vec();
@ -173,7 +175,7 @@ pub fn proof_inputs_to_rln_witness(
let rln_identifier = hash_to_field(RLN_IDENTIFIER);
(
Ok((
RLNWitnessInput {
identity_secret,
path_elements,
@ -183,45 +185,48 @@ pub fn proof_inputs_to_rln_witness(
rln_identifier,
},
all_read,
)
))
}
pub fn rln_witness_from_json(input_json_str: &str) -> RLNWitnessInput {
pub fn rln_witness_from_json(input_json_str: &str) -> Result<RLNWitnessInput> {
let input_json: serde_json::Value =
serde_json::from_str(input_json_str).expect("JSON was not well-formatted");
let identity_secret = str_to_fr(&input_json["identity_secret"].to_string(), 10);
let identity_secret = str_to_fr(&input_json["identity_secret"].to_string(), 10)?;
let path_elements = input_json["path_elements"]
.as_array()
.unwrap()
.ok_or(Report::msg("not an array"))?
.iter()
.map(|v| str_to_fr(&v.to_string(), 10))
.collect();
.collect::<Result<_>>()?;
let identity_path_index = input_json["identity_path_index"]
let identity_path_index_array = input_json["identity_path_index"]
.as_array()
.unwrap()
.iter()
.map(|v| v.as_u64().unwrap() as u8)
.collect();
.ok_or(Report::msg("not an arrray"))?;
let x = str_to_fr(&input_json["x"].to_string(), 10);
let mut identity_path_index: Vec<u8> = vec![];
let epoch = str_to_fr(&input_json["epoch"].to_string(), 16);
for v in identity_path_index_array {
identity_path_index.push(v.as_u64().ok_or(Report::msg("not a u64 value"))? as u8);
}
let rln_identifier = str_to_fr(&input_json["rln_identifier"].to_string(), 10);
let x = str_to_fr(&input_json["x"].to_string(), 10)?;
let epoch = str_to_fr(&input_json["epoch"].to_string(), 16)?;
let rln_identifier = str_to_fr(&input_json["rln_identifier"].to_string(), 10)?;
// TODO: check rln_identifier against public::RLN_IDENTIFIER
RLNWitnessInput {
Ok(RLNWitnessInput {
identity_secret,
path_elements,
identity_path_index,
x,
epoch,
rln_identifier,
}
})
}
pub fn rln_witness_from_values(
@ -353,8 +358,8 @@ pub fn prepare_prove_input(
id_index: usize,
epoch: Fr,
signal: &[u8],
) -> Vec<u8> {
let signal_len = u64::try_from(signal.len()).unwrap();
) -> Result<Vec<u8>> {
let signal_len = u64::try_from(signal.len())?;
let mut serialized: Vec<u8> = Vec::new();
@ -364,12 +369,12 @@ pub fn prepare_prove_input(
serialized.append(&mut signal_len.to_le_bytes().to_vec());
serialized.append(&mut signal.to_vec());
serialized
Ok(serialized)
}
#[allow(clippy::redundant_clone)]
pub fn prepare_verify_input(proof_data: Vec<u8>, signal: &[u8]) -> Vec<u8> {
let signal_len = u64::try_from(signal.len()).unwrap();
pub fn prepare_verify_input(proof_data: Vec<u8>, signal: &[u8]) -> Result<Vec<u8>> {
let signal_len = u64::try_from(signal.len())?;
let mut serialized: Vec<u8> = Vec::new();
@ -377,7 +382,7 @@ pub fn prepare_verify_input(proof_data: Vec<u8>, signal: &[u8]) -> Vec<u8> {
serialized.append(&mut signal_len.to_le_bytes().to_vec());
serialized.append(&mut signal.to_vec());
serialized
Ok(serialized)
}
///////////////////////////////////////////////////////
@ -533,9 +538,9 @@ pub fn compute_id_secret(
#[derive(Error, Debug)]
pub enum ProofError {
#[error("Error reading circuit key: {0}")]
CircuitKeyError(#[from] std::io::Error),
CircuitKeyError(#[from] Report),
#[error("Error producing witness: {0}")]
WitnessError(color_eyre::Report),
WitnessError(Report),
#[error("Error producing proof: {0}")]
SynthesisError(#[from] SynthesisError),
}
@ -546,20 +551,21 @@ fn calculate_witness_element<E: ark_ec::PairingEngine>(witness: Vec<BigInt>) ->
// convert it to field elements
use num_traits::Signed;
let witness = witness
.into_iter()
.map(|w| {
let w = if w.sign() == num_bigint::Sign::Minus {
// Need to negate the witness element if negative
modulus.into() - w.abs().to_biguint().unwrap()
} else {
w.to_biguint().unwrap()
};
E::Fr::from(w)
})
.collect::<Vec<_>>();
let mut witness_vec = vec![];
for w in witness.into_iter() {
let w = if w.sign() == num_bigint::Sign::Minus {
// Need to negate the witness element if negative
modulus.into()
- w.abs()
.to_biguint()
.ok_or(Report::msg("not a biguint value"))?
} else {
w.to_biguint().ok_or(Report::msg("not a biguint value"))?
};
witness_vec.push(E::Fr::from(w))
}
Ok(witness)
Ok(witness_vec)
}
pub fn generate_proof_with_witness(
@ -570,9 +576,8 @@ pub fn generate_proof_with_witness(
#[cfg(debug_assertions)]
let now = Instant::now();
let full_assignment = calculate_witness_element::<Curve>(witness)
.map_err(ProofError::WitnessError)
.unwrap();
let full_assignment =
calculate_witness_element::<Curve>(witness).map_err(ProofError::WitnessError)?;
#[cfg(debug_assertions)]
println!("witness generation took: {:.2?}", now.elapsed());
@ -594,8 +599,7 @@ pub fn generate_proof_with_witness(
proving_key.1.num_instance_variables,
proving_key.1.num_constraints,
full_assignment.as_slice(),
)
.unwrap();
)?;
#[cfg(debug_assertions)]
println!("proof generation took: {:.2?}", now.elapsed());
@ -603,14 +607,16 @@ pub fn generate_proof_with_witness(
Ok(proof)
}
pub fn inputs_for_witness_calculation(rln_witness: &RLNWitnessInput) -> [(&str, Vec<BigInt>); 6] {
pub fn inputs_for_witness_calculation(
rln_witness: &RLNWitnessInput,
) -> Result<[(&str, Vec<BigInt>); 6]> {
// We confert the path indexes to field elements
// TODO: check if necessary
let mut path_elements = Vec::new();
rln_witness
.path_elements
.iter()
.for_each(|v| path_elements.push(to_bigint(v)));
for v in rln_witness.path_elements.iter() {
path_elements.push(to_bigint(v)?);
}
let mut identity_path_index = Vec::new();
rln_witness
@ -618,20 +624,20 @@ pub fn inputs_for_witness_calculation(rln_witness: &RLNWitnessInput) -> [(&str,
.iter()
.for_each(|v| identity_path_index.push(BigInt::from(*v)));
[
Ok([
(
"identity_secret",
vec![to_bigint(&rln_witness.identity_secret)],
vec![to_bigint(&rln_witness.identity_secret)?],
),
("path_elements", path_elements),
("identity_path_index", identity_path_index),
("x", vec![to_bigint(&rln_witness.x)]),
("epoch", vec![to_bigint(&rln_witness.epoch)]),
("x", vec![to_bigint(&rln_witness.x)?]),
("epoch", vec![to_bigint(&rln_witness.epoch)?]),
(
"rln_identifier",
vec![to_bigint(&rln_witness.rln_identifier)],
vec![to_bigint(&rln_witness.rln_identifier)?],
),
]
])
}
/// Generates a RLN proof
@ -645,7 +651,7 @@ pub fn generate_proof(
proving_key: &(ProvingKey<Curve>, ConstraintMatrices<Fr>),
rln_witness: &RLNWitnessInput,
) -> Result<ArkProof<Curve>, ProofError> {
let inputs = inputs_for_witness_calculation(rln_witness)
let inputs = inputs_for_witness_calculation(rln_witness)?
.into_iter()
.map(|(name, values)| (name.to_string(), values));
@ -736,12 +742,12 @@ pub fn verify_proof(
///
/// Returns a JSON object containing the inputs necessary to calculate
/// the witness with CIRCOM on javascript
pub fn get_json_inputs(rln_witness: &RLNWitnessInput) -> serde_json::Value {
pub fn get_json_inputs(rln_witness: &RLNWitnessInput) -> Result<serde_json::Value> {
let mut path_elements = Vec::new();
rln_witness
.path_elements
.iter()
.for_each(|v| path_elements.push(to_bigint(v).to_str_radix(10)));
for v in rln_witness.path_elements.iter() {
path_elements.push(to_bigint(v)?.to_str_radix(10));
}
let mut identity_path_index = Vec::new();
rln_witness
@ -750,13 +756,13 @@ pub fn get_json_inputs(rln_witness: &RLNWitnessInput) -> serde_json::Value {
.for_each(|v| identity_path_index.push(BigInt::from(*v).to_str_radix(10)));
let inputs = serde_json::json!({
"identity_secret": to_bigint(&rln_witness.identity_secret).to_str_radix(10),
"identity_secret": to_bigint(&rln_witness.identity_secret)?.to_str_radix(10),
"path_elements": path_elements,
"identity_path_index": identity_path_index,
"x": to_bigint(&rln_witness.x).to_str_radix(10),
"epoch": format!("0x{:064x}", to_bigint(&rln_witness.epoch)),
"rln_identifier": to_bigint(&rln_witness.rln_identifier).to_str_radix(10),
"x": to_bigint(&rln_witness.x)?.to_str_radix(10),
"epoch": format!("0x{:064x}", to_bigint(&rln_witness.epoch)?),
"rln_identifier": to_bigint(&rln_witness.rln_identifier)?.to_str_radix(10),
});
inputs
Ok(inputs)
}

View File

@ -10,9 +10,9 @@ use ark_groth16::{ProvingKey, VerifyingKey};
use ark_relations::r1cs::ConstraintMatrices;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, Write};
use cfg_if::cfg_if;
use color_eyre::Result;
use num_bigint::BigInt;
use std::io::Cursor;
use std::io::{self, Result};
cfg_if! {
if #[cfg(not(target_arch = "wasm32"))] {
@ -36,8 +36,8 @@ pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809";
///
/// I/O is mostly done using writers and readers implementing `std::io::Write` and `std::io::Read`, respectively.
pub struct RLN<'a> {
proving_key: Result<(ProvingKey<Curve>, ConstraintMatrices<Fr>)>,
verification_key: Result<VerifyingKey<Curve>>,
proving_key: (ProvingKey<Curve>, ConstraintMatrices<Fr>),
verification_key: VerifyingKey<Curve>,
tree: PoseidonTree,
// The witness calculator can't be loaded in zerokit. Since this struct
@ -67,29 +67,29 @@ impl RLN<'_> {
/// let mut rln = RLN::new(tree_height, resources);
/// ```
#[cfg(not(target_arch = "wasm32"))]
pub fn new<R: Read>(tree_height: usize, mut input_data: R) -> RLN<'static> {
pub fn new<R: Read>(tree_height: usize, mut input_data: R) -> Result<RLN<'static>> {
// We read input
let mut input: Vec<u8> = Vec::new();
input_data.read_to_end(&mut input).unwrap();
input_data.read_to_end(&mut input)?;
let resources_folder = String::from_utf8(input).expect("Found invalid UTF-8");
let resources_folder = String::from_utf8(input)?;
let witness_calculator = circom_from_folder(&resources_folder);
let witness_calculator = circom_from_folder(&resources_folder)?;
let proving_key = zkey_from_folder(&resources_folder);
let verification_key = vk_from_folder(&resources_folder);
let proving_key = zkey_from_folder(&resources_folder)?;
let verification_key = vk_from_folder(&resources_folder)?;
// We compute a default empty tree
let tree = PoseidonTree::default(tree_height);
RLN {
Ok(RLN {
witness_calculator,
proving_key,
verification_key,
tree,
#[cfg(target_arch = "wasm32")]
_marker: PhantomData,
}
})
}
/// Creates a new RLN object by passing circuit resources as byte vectors.
@ -130,17 +130,17 @@ impl RLN<'_> {
#[cfg(not(target_arch = "wasm32"))] circom_vec: Vec<u8>,
zkey_vec: Vec<u8>,
vk_vec: Vec<u8>,
) -> RLN<'static> {
) -> Result<RLN<'static>> {
#[cfg(not(target_arch = "wasm32"))]
let witness_calculator = circom_from_raw(circom_vec);
let witness_calculator = circom_from_raw(circom_vec)?;
let proving_key = zkey_from_raw(&zkey_vec);
let verification_key = vk_from_raw(&vk_vec, &zkey_vec);
let proving_key = zkey_from_raw(&zkey_vec)?;
let verification_key = vk_from_raw(&vk_vec, &zkey_vec)?;
// We compute a default empty tree
let tree = PoseidonTree::default(tree_height);
RLN {
Ok(RLN {
#[cfg(not(target_arch = "wasm32"))]
witness_calculator,
proving_key,
@ -148,7 +148,7 @@ impl RLN<'_> {
tree,
#[cfg(target_arch = "wasm32")]
_marker: PhantomData,
}
})
}
////////////////////////////////////////////////////////
@ -160,7 +160,7 @@ impl RLN<'_> {
///
/// Input values are:
/// - `tree_height`: the height of the Merkle tree.
pub fn set_tree(&mut self, tree_height: usize) -> io::Result<()> {
pub fn set_tree(&mut self, tree_height: usize) -> Result<()> {
// We compute a default empty tree of desired height
self.tree = PoseidonTree::default(tree_height);
@ -187,7 +187,7 @@ impl RLN<'_> {
/// let mut buffer = Cursor::new(serialize_field_element(id_commitment));
/// rln.set_leaf(id_index, &mut buffer).unwrap();
/// ```
pub fn set_leaf<R: Read>(&mut self, index: usize, mut input_data: R) -> io::Result<()> {
pub fn set_leaf<R: Read>(&mut self, index: usize, mut input_data: R) -> Result<()> {
// We read input
let mut leaf_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut leaf_byte)?;
@ -229,12 +229,12 @@ impl RLN<'_> {
/// let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
/// rln.set_leaves_from(index, &mut buffer).unwrap();
/// ```
pub fn set_leaves_from<R: Read>(&mut self, index: usize, mut input_data: R) -> io::Result<()> {
pub fn set_leaves_from<R: Read>(&mut self, index: usize, mut input_data: R) -> Result<()> {
// We read input
let mut leaves_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut leaves_byte)?;
let (leaves, _) = bytes_le_to_vec_fr(&leaves_byte);
let (leaves, _) = bytes_le_to_vec_fr(&leaves_byte)?;
// We set the leaves
self.tree.set_range(index, leaves)
@ -246,7 +246,7 @@ impl RLN<'_> {
///
/// Input values are:
/// - `input_data`: a reader for the serialization of multiple leaf values (serialization done with [`rln::utils::vec_fr_to_bytes_le`](crate::utils::vec_fr_to_bytes_le))
pub fn init_tree_with_leaves<R: Read>(&mut self, input_data: R) -> io::Result<()> {
pub fn init_tree_with_leaves<R: Read>(&mut self, input_data: R) -> Result<()> {
// reset the tree
// NOTE: this requires the tree to be initialized with the correct height initially
// TODO: accept tree_height as a parameter and initialize the tree with that height
@ -295,7 +295,7 @@ impl RLN<'_> {
/// let mut buffer = Cursor::new(fr_to_bytes_le(&id_commitment));
/// rln.set_next_leaf(&mut buffer).unwrap();
/// ```
pub fn set_next_leaf<R: Read>(&mut self, mut input_data: R) -> io::Result<()> {
pub fn set_next_leaf<R: Read>(&mut self, mut input_data: R) -> Result<()> {
// We read input
let mut leaf_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut leaf_byte)?;
@ -320,7 +320,7 @@ impl RLN<'_> {
/// let index = 10;
/// rln.delete_leaf(index).unwrap();
/// ```
pub fn delete_leaf(&mut self, index: usize) -> io::Result<()> {
pub fn delete_leaf(&mut self, index: usize) -> Result<()> {
self.tree.delete(index)?;
Ok(())
}
@ -338,7 +338,7 @@ impl RLN<'_> {
/// rln.get_root(&mut buffer).unwrap();
/// let (root, _) = bytes_le_to_fr(&buffer.into_inner());
/// ```
pub fn get_root<W: Write>(&self, mut output_data: W) -> io::Result<()> {
pub fn get_root<W: Write>(&self, mut output_data: W) -> Result<()> {
let root = self.tree.root();
output_data.write_all(&fr_to_bytes_le(&root))?;
@ -366,13 +366,13 @@ impl RLN<'_> {
/// let (path_elements, read) = bytes_le_to_vec_fr(&buffer_inner);
/// let (identity_path_index, _) = bytes_le_to_vec_u8(&buffer_inner[read..].to_vec());
/// ```
pub fn get_proof<W: Write>(&self, index: usize, mut output_data: W) -> io::Result<()> {
pub fn get_proof<W: Write>(&self, index: usize, mut output_data: W) -> Result<()> {
let merkle_proof = self.tree.proof(index).expect("proof should exist");
let path_elements = merkle_proof.get_path_elements();
let identity_path_index = merkle_proof.get_path_index();
output_data.write_all(&vec_fr_to_bytes_le(&path_elements))?;
output_data.write_all(&vec_u8_to_bytes_le(&identity_path_index))?;
output_data.write_all(&vec_fr_to_bytes_le(&path_elements)?)?;
output_data.write_all(&vec_u8_to_bytes_le(&identity_path_index)?)?;
Ok(())
}
@ -406,11 +406,11 @@ impl RLN<'_> {
&mut self,
mut input_data: R,
mut output_data: W,
) -> io::Result<()> {
) -> Result<()> {
// We read input RLN witness and we deserialize it
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let (rln_witness, _) = deserialize_witness(&serialized);
let (rln_witness, _) = deserialize_witness(&serialized)?;
/*
if self.witness_calculator.is_none() {
@ -418,15 +418,10 @@ impl RLN<'_> {
}
*/
let proof = generate_proof(
self.witness_calculator,
self.proving_key.as_ref().unwrap(),
&rln_witness,
)
.unwrap();
let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?;
// Note: we export a serialization of ark-groth16::Proof not semaphore::Proof
proof.serialize(&mut output_data).unwrap();
proof.serialize(&mut output_data)?;
Ok(())
}
@ -466,22 +461,17 @@ impl RLN<'_> {
///
/// assert!(verified);
/// ```
pub fn verify<R: Read>(&self, mut input_data: R) -> io::Result<bool> {
pub fn verify<R: Read>(&self, mut input_data: R) -> Result<bool> {
// Input data is serialized for Curve as:
// serialized_proof (compressed, 4*32 bytes) || serialized_proof_values (6*32 bytes), i.e.
// [ proof<128> | root<32> | epoch<32> | share_x<32> | share_y<32> | nullifier<32> | rln_identifier<32> ]
let mut input_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut input_byte)?;
let proof = ArkProof::deserialize(&mut Cursor::new(&input_byte[..128])).unwrap();
let proof = ArkProof::deserialize(&mut Cursor::new(&input_byte[..128]))?;
let (proof_values, _) = deserialize_proof_values(&input_byte[128..]);
let verified = verify_proof(
self.verification_key.as_ref().unwrap(),
&proof,
&proof_values,
)
.unwrap();
let verified = verify_proof(&self.verification_key, &proof, &proof_values)?;
Ok(verified)
}
@ -537,23 +527,18 @@ impl RLN<'_> {
&mut self,
mut input_data: R,
mut output_data: W,
) -> io::Result<()> {
) -> Result<()> {
// We read input RLN witness and we deserialize it
let mut witness_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut witness_byte)?;
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte);
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte)?;
let proof_values = proof_values_from_witness(&rln_witness);
let proof = generate_proof(
self.witness_calculator,
self.proving_key.as_ref().unwrap(),
&rln_witness,
)
.unwrap();
let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?;
// Note: we export a serialization of ark-groth16::Proof not semaphore::Proof
// This proof is compressed, i.e. 128 bytes long
proof.serialize(&mut output_data).unwrap();
proof.serialize(&mut output_data)?;
output_data.write_all(&serialize_proof_values(&proof_values))?;
Ok(())
@ -570,17 +555,15 @@ impl RLN<'_> {
calculated_witness: Vec<BigInt>,
rln_witness_vec: Vec<u8>,
mut output_data: W,
) -> io::Result<()> {
let (rln_witness, _) = deserialize_witness(&rln_witness_vec[..]);
) -> Result<()> {
let (rln_witness, _) = deserialize_witness(&rln_witness_vec[..])?;
let proof_values = proof_values_from_witness(&rln_witness);
let proof =
generate_proof_with_witness(calculated_witness, self.proving_key.as_ref().unwrap())
.unwrap();
let proof = generate_proof_with_witness(calculated_witness, &self.proving_key).unwrap();
// Note: we export a serialization of ark-groth16::Proof not semaphore::Proof
// This proof is compressed, i.e. 128 bytes long
proof.serialize(&mut output_data).unwrap();
proof.serialize(&mut output_data)?;
output_data.write_all(&serialize_proof_values(&proof_values))?;
Ok(())
}
@ -612,27 +595,22 @@ impl RLN<'_> {
///
/// assert!(verified);
/// ```
pub fn verify_rln_proof<R: Read>(&self, mut input_data: R) -> io::Result<bool> {
pub fn verify_rln_proof<R: Read>(&self, mut input_data: R) -> Result<bool> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let mut all_read = 0;
let proof = ArkProof::deserialize(&mut Cursor::new(&serialized[..128].to_vec())).unwrap();
let proof = ArkProof::deserialize(&mut Cursor::new(&serialized[..128].to_vec()))?;
all_read += 128;
let (proof_values, read) = deserialize_proof_values(&serialized[all_read..]);
all_read += read;
let signal_len =
u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()) as usize;
u64::from_le_bytes(serialized[all_read..all_read + 8].try_into()?) as usize;
all_read += 8;
let signal: Vec<u8> = serialized[all_read..all_read + signal_len].to_vec();
let verified = verify_proof(
self.verification_key.as_ref().unwrap(),
&proof,
&proof_values,
)
.unwrap();
let verified = verify_proof(&self.verification_key, &proof, &proof_values)?;
// Consistency checks to counter proof tampering
let x = hash_to_field(&signal);
@ -693,31 +671,22 @@ impl RLN<'_> {
///
/// assert!(verified);
/// ```
pub fn verify_with_roots<R: Read>(
&self,
mut input_data: R,
mut roots_data: R,
) -> io::Result<bool> {
pub fn verify_with_roots<R: Read>(&self, mut input_data: R, mut roots_data: R) -> Result<bool> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let mut all_read = 0;
let proof = ArkProof::deserialize(&mut Cursor::new(&serialized[..128].to_vec())).unwrap();
let proof = ArkProof::deserialize(&mut Cursor::new(&serialized[..128].to_vec()))?;
all_read += 128;
let (proof_values, read) = deserialize_proof_values(&serialized[all_read..]);
all_read += read;
let signal_len =
u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()) as usize;
u64::from_le_bytes(serialized[all_read..all_read + 8].try_into()?) as usize;
all_read += 8;
let signal: Vec<u8> = serialized[all_read..all_read + signal_len].to_vec();
let verified = verify_proof(
self.verification_key.as_ref().unwrap(),
&proof,
&proof_values,
)
.unwrap();
let verified = verify_proof(&self.verification_key, &proof, &proof_values)?;
// First consistency checks to counter proof tampering
let x = hash_to_field(&signal);
@ -783,7 +752,7 @@ impl RLN<'_> {
/// // We deserialize the keygen output
/// let (identity_secret_hash, id_commitment) = deserialize_identity_pair(buffer.into_inner());
/// ```
pub fn key_gen<W: Write>(&self, mut output_data: W) -> io::Result<()> {
pub fn key_gen<W: Write>(&self, mut output_data: W) -> Result<()> {
let (identity_secret_hash, id_commitment) = keygen();
output_data.write_all(&fr_to_bytes_le(&identity_secret_hash))?;
output_data.write_all(&fr_to_bytes_le(&id_commitment))?;
@ -813,7 +782,7 @@ impl RLN<'_> {
/// // We deserialize the keygen output
/// let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) = deserialize_identity_tuple(buffer.into_inner());
/// ```
pub fn extended_key_gen<W: Write>(&self, mut output_data: W) -> io::Result<()> {
pub fn extended_key_gen<W: Write>(&self, mut output_data: W) -> Result<()> {
let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) =
extended_keygen();
output_data.write_all(&fr_to_bytes_le(&identity_trapdoor))?;
@ -852,7 +821,7 @@ impl RLN<'_> {
&self,
mut input_data: R,
mut output_data: W,
) -> io::Result<()> {
) -> Result<()> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
@ -895,7 +864,7 @@ impl RLN<'_> {
&self,
mut input_data: R,
mut output_data: W,
) -> io::Result<()> {
) -> Result<()> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
@ -946,7 +915,7 @@ impl RLN<'_> {
mut input_proof_data_1: R,
mut input_proof_data_2: R,
mut output_data: W,
) -> io::Result<()> {
) -> Result<()> {
// We deserialize the two proofs and we get the corresponding RLNProofValues objects
let mut serialized: Vec<u8> = Vec::new();
input_proof_data_1.read_to_end(&mut serialized)?;
@ -990,11 +959,11 @@ impl RLN<'_> {
/// - `input_data`: a reader for the serialization of `[ identity_secret<32> | id_index<8> | epoch<32> | signal_len<8> | signal<var> ]`
///
/// The function returns the corresponding [`RLNWitnessInput`](crate::protocol::RLNWitnessInput) object serialized using [`rln::protocol::serialize_witness`](crate::protocol::serialize_witness)).
pub fn get_serialized_rln_witness<R: Read>(&mut self, mut input_data: R) -> Vec<u8> {
pub fn get_serialized_rln_witness<R: Read>(&mut self, mut input_data: R) -> Result<Vec<u8>> {
// We read input RLN witness and we deserialize it
let mut witness_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut witness_byte).unwrap();
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte);
input_data.read_to_end(&mut witness_byte)?;
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte)?;
serialize_witness(&rln_witness)
}
@ -1005,12 +974,9 @@ impl RLN<'_> {
/// - `serialized_witness`: the byte serialization of a [`RLNWitnessInput`](crate::protocol::RLNWitnessInput) object (serialization done with [`rln::protocol::serialize_witness`](crate::protocol::serialize_witness)).
///
/// The function returns the corresponding JSON encoding of the input [`RLNWitnessInput`](crate::protocol::RLNWitnessInput) object.
pub fn get_rln_witness_json(
&mut self,
serialized_witness: &[u8],
) -> io::Result<serde_json::Value> {
let (rln_witness, _) = deserialize_witness(serialized_witness);
Ok(get_json_inputs(&rln_witness))
pub fn get_rln_witness_json(&mut self, serialized_witness: &[u8]) -> Result<serde_json::Value> {
let (rln_witness, _) = deserialize_witness(serialized_witness)?;
get_json_inputs(&rln_witness)
}
}
@ -1019,7 +985,7 @@ impl Default for RLN<'_> {
fn default() -> Self {
let tree_height = TEST_TREE_HEIGHT;
let buffer = Cursor::new(TEST_RESOURCES_FOLDER);
Self::new(tree_height, buffer)
Self::new(tree_height, buffer).unwrap()
}
}
@ -1045,7 +1011,7 @@ impl Default for RLN<'_> {
/// // We deserialize the keygen output
/// let field_element = deserialize_field_element(output_buffer.into_inner());
/// ```
pub fn hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> io::Result<()> {
pub fn hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> Result<()> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
@ -1078,11 +1044,11 @@ pub fn hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> io::Res
/// // We deserialize the hash output
/// let hash_result = deserialize_field_element(output_buffer.into_inner());
/// ```
pub fn poseidon_hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> io::Result<()> {
pub fn poseidon_hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> Result<()> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let (inputs, _) = bytes_le_to_vec_fr(&serialized);
let (inputs, _) = bytes_le_to_vec_fr(&serialized)?;
let hash = utils_poseidon_hash(inputs.as_ref());
output_data.write_all(&fr_to_bytes_le(&hash))?;
@ -1110,7 +1076,7 @@ mod test {
// We create a new tree
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// We first add leaves one by one specifying the index
for (i, leaf) in leaves.iter().enumerate() {
@ -1149,7 +1115,7 @@ mod test {
rln.set_tree(tree_height).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
@ -1205,10 +1171,10 @@ mod test {
// We create a new tree
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
@ -1222,11 +1188,11 @@ mod test {
// `init_tree_with_leaves` resets the tree to the height it was initialized with, using `set_tree`
// We add leaves in a batch starting from index 0..set_index
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[0..set_index]));
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[0..set_index]).unwrap());
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We add the remaining n leaves in a batch starting from index m
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[set_index..]));
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[set_index..]).unwrap());
rln.set_leaves_from(set_index, &mut buffer).unwrap();
// We check if number of leaves set is consistent
@ -1259,6 +1225,7 @@ mod test {
assert_eq!(root_batch_with_init, root_single_additions);
}
#[allow(unused_must_use)]
#[test]
// This test checks if `set_leaves_from` throws an error when the index is out of bounds
fn test_set_leaves_bad_index() {
@ -1275,7 +1242,7 @@ mod test {
// We create a new tree
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// Get root of empty tree
let mut buffer = Cursor::new(Vec::<u8>::new());
@ -1283,7 +1250,7 @@ mod test {
let (root_empty, _) = bytes_le_to_fr(&buffer.into_inner());
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
rln.set_leaves_from(bad_index, &mut buffer)
.expect_err("Should throw an error");
@ -1304,25 +1271,21 @@ mod test {
let tree_height = TEST_TREE_HEIGHT;
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// Note: we only test Groth16 proof generation, so we ignore setting the tree in the RLN object
let rln_witness = random_rln_witness(tree_height);
let proof_values = proof_values_from_witness(&rln_witness);
// We compute a Groth16 proof
let mut input_buffer = Cursor::new(serialize_witness(&rln_witness));
let mut input_buffer = Cursor::new(serialize_witness(&rln_witness).unwrap());
let mut output_buffer = Cursor::new(Vec::<u8>::new());
rln.prove(&mut input_buffer, &mut output_buffer).unwrap();
let serialized_proof = output_buffer.into_inner();
// Before checking public verify API, we check that the (deserialized) proof generated by prove is actually valid
let proof = ArkProof::deserialize(&mut Cursor::new(&serialized_proof)).unwrap();
let verified = verify_proof(
&rln.verification_key.as_ref().unwrap(),
&proof,
&proof_values,
);
let verified = verify_proof(&rln.verification_key, &proof, &proof_values);
assert!(verified.unwrap());
// We prepare the input to prove API, consisting of serialized_proof (compressed, 4*32 bytes) || serialized_proof_values (6*32 bytes)
@ -1352,10 +1315,10 @@ mod test {
// We create a new RLN instance
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
rln.init_tree_with_leaves(&mut buffer).unwrap();
// Generate identity pair
@ -1417,10 +1380,10 @@ mod test {
// We create a new RLN instance
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
rln.init_tree_with_leaves(&mut buffer).unwrap();
// Generate identity pair
@ -1453,12 +1416,13 @@ mod test {
// We read input RLN witness and we deserialize it
let mut witness_byte: Vec<u8> = Vec::new();
input_buffer.read_to_end(&mut witness_byte).unwrap();
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut rln.tree, &witness_byte);
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut rln.tree, &witness_byte).unwrap();
let serialized_witness = serialize_witness(&rln_witness);
let serialized_witness = serialize_witness(&rln_witness).unwrap();
// Calculate witness outside zerokit (simulating what JS is doing)
let inputs = inputs_for_witness_calculation(&rln_witness)
.unwrap()
.into_iter()
.map(|(name, values)| (name.to_string(), values));
let calculated_witness = rln
@ -1471,7 +1435,7 @@ mod test {
let calculated_witness_vec: Vec<BigInt> = calculated_witness
.into_iter()
.map(|v| to_bigint(&v))
.map(|v| to_bigint(&v).unwrap())
.collect();
// Generating the proof
@ -1513,10 +1477,10 @@ mod test {
// We create a new RLN instance
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
rln.init_tree_with_leaves(&mut buffer).unwrap();
// Generate identity pair
@ -1600,7 +1564,7 @@ mod test {
// We create a new RLN instance
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// Generate identity pair
let (identity_secret_hash, id_commitment) = keygen();

View File

@ -2,13 +2,14 @@
use crate::circuit::Fr;
use ark_ff::PrimeField;
use color_eyre::{Report, Result};
use num_bigint::{BigInt, BigUint};
use num_traits::Num;
use std::iter::Extend;
pub fn to_bigint(el: &Fr) -> BigInt {
let res: BigUint = (*el).try_into().unwrap();
res.try_into().unwrap()
pub fn to_bigint(el: &Fr) -> Result<BigInt> {
let res: BigUint = (*el).try_into()?;
Ok(res.into())
}
pub fn fr_byte_size() -> usize {
@ -16,8 +17,10 @@ pub fn fr_byte_size() -> usize {
(mbs + 64 - (mbs % 64)) / 8
}
pub fn str_to_fr(input: &str, radix: u32) -> Fr {
assert!((radix == 10) || (radix == 16));
pub fn str_to_fr(input: &str, radix: u32) -> Result<Fr> {
if !(radix == 10 || radix == 16) {
return Err(Report::msg("wrong radix"));
}
// We remove any quote present and we trim
let single_quote: char = '\"';
@ -25,16 +28,10 @@ pub fn str_to_fr(input: &str, radix: u32) -> Fr {
input_clean = input_clean.trim().to_string();
if radix == 10 {
BigUint::from_str_radix(&input_clean, radix)
.unwrap()
.try_into()
.unwrap()
Ok(BigUint::from_str_radix(&input_clean, radix)?.try_into()?)
} else {
input_clean = input_clean.replace("0x", "");
BigUint::from_str_radix(&input_clean, radix)
.unwrap()
.try_into()
.unwrap()
Ok(BigUint::from_str_radix(&input_clean, radix)?.try_into()?)
}
}
@ -75,72 +72,73 @@ pub fn fr_to_bytes_be(input: &Fr) -> Vec<u8> {
res
}
pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Vec<u8> {
pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Result<Vec<u8>> {
let mut bytes: Vec<u8> = Vec::new();
//We store the vector length
bytes.extend(u64::try_from(input.len()).unwrap().to_le_bytes().to_vec());
bytes.extend(u64::try_from(input.len())?.to_le_bytes().to_vec());
// We store each element
input.iter().for_each(|el| bytes.extend(fr_to_bytes_le(el)));
bytes
Ok(bytes)
}
pub fn vec_fr_to_bytes_be(input: &[Fr]) -> Vec<u8> {
pub fn vec_fr_to_bytes_be(input: &[Fr]) -> Result<Vec<u8>> {
let mut bytes: Vec<u8> = Vec::new();
//We store the vector length
bytes.extend(u64::try_from(input.len()).unwrap().to_be_bytes().to_vec());
bytes.extend(u64::try_from(input.len())?.to_be_bytes().to_vec());
// We store each element
input.iter().for_each(|el| bytes.extend(fr_to_bytes_be(el)));
bytes
Ok(bytes)
}
pub fn vec_u8_to_bytes_le(input: &[u8]) -> Vec<u8> {
pub fn vec_u8_to_bytes_le(input: &[u8]) -> Result<Vec<u8>> {
let mut bytes: Vec<u8> = Vec::new();
//We store the vector length
bytes.extend(u64::try_from(input.len()).unwrap().to_le_bytes().to_vec());
bytes.extend(u64::try_from(input.len())?.to_le_bytes().to_vec());
bytes.extend(input);
bytes
Ok(bytes)
}
pub fn vec_u8_to_bytes_be(input: Vec<u8>) -> Vec<u8> {
let mut bytes: Vec<u8> = Vec::new();
pub fn vec_u8_to_bytes_be(input: Vec<u8>) -> Result<Vec<u8>> {
//We store the vector length
bytes.extend(u64::try_from(input.len()).unwrap().to_be_bytes().to_vec());
let mut bytes: Vec<u8> = u64::try_from(input.len())?.to_be_bytes().to_vec();
bytes.extend(input);
bytes
Ok(bytes)
}
pub fn bytes_le_to_vec_u8(input: &[u8]) -> (Vec<u8>, usize) {
pub fn bytes_le_to_vec_u8(input: &[u8]) -> Result<(Vec<u8>, usize)> {
let mut read: usize = 0;
let len = u64::from_le_bytes(input[0..8].try_into().unwrap()) as usize;
let len = u64::from_le_bytes(input[0..8].try_into()?) as usize;
read += 8;
let res = input[8..8 + len].to_vec();
read += res.len();
(res, read)
Ok((res, read))
}
pub fn bytes_be_to_vec_u8(input: &[u8]) -> (Vec<u8>, usize) {
pub fn bytes_be_to_vec_u8(input: &[u8]) -> Result<(Vec<u8>, usize)> {
let mut read: usize = 0;
let len = u64::from_be_bytes(input[0..8].try_into().unwrap()) as usize;
let len = u64::from_be_bytes(input[0..8].try_into()?) as usize;
read += 8;
let res = input[8..8 + len].to_vec();
read += res.len();
(res, read)
Ok((res, read))
}
pub fn bytes_le_to_vec_fr(input: &[u8]) -> (Vec<Fr>, usize) {
pub fn bytes_le_to_vec_fr(input: &[u8]) -> Result<(Vec<Fr>, usize)> {
let mut read: usize = 0;
let mut res: Vec<Fr> = Vec::new();
let len = u64::from_le_bytes(input[0..8].try_into().unwrap()) as usize;
let len = u64::from_le_bytes(input[0..8].try_into()?) as usize;
read += 8;
let el_size = fr_byte_size();
@ -150,14 +148,14 @@ pub fn bytes_le_to_vec_fr(input: &[u8]) -> (Vec<Fr>, usize) {
read += el_size;
}
(res, read)
Ok((res, read))
}
pub fn bytes_be_to_vec_fr(input: &[u8]) -> (Vec<Fr>, usize) {
pub fn bytes_be_to_vec_fr(input: &[u8]) -> Result<(Vec<Fr>, usize)> {
let mut read: usize = 0;
let mut res: Vec<Fr> = Vec::new();
let len = u64::from_be_bytes(input[0..8].try_into().unwrap()) as usize;
let len = u64::from_be_bytes(input[0..8].try_into()?) as usize;
read += 8;
let el_size = fr_byte_size();
@ -167,7 +165,7 @@ pub fn bytes_be_to_vec_fr(input: &[u8]) -> (Vec<Fr>, usize) {
read += el_size;
}
(res, read)
Ok((res, read))
}
/* Old conversion utilities between different libraries data types

View File

@ -78,7 +78,7 @@ mod test {
assert!(success, "set tree call failed");
// We add leaves in a batch into the tree
let leaves_ser = vec_fr_to_bytes_le(&leaves);
let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap();
let input_buffer = &Buffer::from(leaves_ser.as_ref());
let success = init_tree_with_leaves(rln_pointer, input_buffer);
assert!(success, "init tree with leaves call failed");
@ -153,7 +153,7 @@ mod test {
let set_index = rng.gen_range(0..no_of_leaves) as usize;
// We add leaves in a batch into the tree
let leaves_ser = vec_fr_to_bytes_le(&leaves);
let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap();
let input_buffer = &Buffer::from(leaves_ser.as_ref());
let success = init_tree_with_leaves(rln_pointer, input_buffer);
assert!(success, "init tree with leaves call failed");
@ -170,13 +170,13 @@ mod test {
// `init_tree_with_leaves` resets the tree to the height it was initialized with, using `set_tree`
// We add leaves in a batch starting from index 0..set_index
let leaves_m = vec_fr_to_bytes_le(&leaves[0..set_index]);
let leaves_m = vec_fr_to_bytes_le(&leaves[0..set_index]).unwrap();
let buffer = &Buffer::from(leaves_m.as_ref());
let success = init_tree_with_leaves(rln_pointer, buffer);
assert!(success, "init tree with leaves call failed");
// We add the remaining n leaves in a batch starting from index set_index
let leaves_n = vec_fr_to_bytes_le(&leaves[set_index..]);
let leaves_n = vec_fr_to_bytes_le(&leaves[set_index..]).unwrap();
let buffer = &Buffer::from(leaves_n.as_ref());
let success = set_leaves_from(rln_pointer, set_index, buffer);
assert!(success, "set leaves from call failed");
@ -248,7 +248,7 @@ mod test {
let (root_empty, _) = bytes_le_to_fr(&result_data);
// We add leaves in a batch into the tree
let leaves = vec_fr_to_bytes_le(&leaves);
let leaves = vec_fr_to_bytes_le(&leaves).unwrap();
let buffer = &Buffer::from(leaves.as_ref());
let success = set_leaves_from(rln_pointer, bad_index, buffer);
assert!(!success, "set leaves from call succeeded");
@ -303,71 +303,86 @@ mod test {
let output_buffer = unsafe { output_buffer.assume_init() };
let result_data = <&[u8]>::from(&output_buffer).to_vec();
let (path_elements, read) = bytes_le_to_vec_fr(&result_data);
let (identity_path_index, _) = bytes_le_to_vec_u8(&result_data[read..].to_vec());
let (path_elements, read) = bytes_le_to_vec_fr(&result_data).unwrap();
let (identity_path_index, _) = bytes_le_to_vec_u8(&result_data[read..].to_vec()).unwrap();
// We check correct computation of the path and indexes
let mut expected_path_elements = vec![
str_to_fr(
"0x0000000000000000000000000000000000000000000000000000000000000000",
16,
),
)
.unwrap(),
str_to_fr(
"0x2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864",
16,
),
)
.unwrap(),
str_to_fr(
"0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1",
16,
),
)
.unwrap(),
str_to_fr(
"0x18f43331537ee2af2e3d758d50f72106467c6eea50371dd528d57eb2b856d238",
16,
),
)
.unwrap(),
str_to_fr(
"0x07f9d837cb17b0d36320ffe93ba52345f1b728571a568265caac97559dbc952a",
16,
),
)
.unwrap(),
str_to_fr(
"0x2b94cf5e8746b3f5c9631f4c5df32907a699c58c94b2ad4d7b5cec1639183f55",
16,
),
)
.unwrap(),
str_to_fr(
"0x2dee93c5a666459646ea7d22cca9e1bcfed71e6951b953611d11dda32ea09d78",
16,
),
)
.unwrap(),
str_to_fr(
"0x078295e5a22b84e982cf601eb639597b8b0515a88cb5ac7fa8a4aabe3c87349d",
16,
),
)
.unwrap(),
str_to_fr(
"0x2fa5e5f18f6027a6501bec864564472a616b2e274a41211a444cbe3a99f3cc61",
16,
),
)
.unwrap(),
str_to_fr(
"0x0e884376d0d8fd21ecb780389e941f66e45e7acce3e228ab3e2156a614fcd747",
16,
),
)
.unwrap(),
str_to_fr(
"0x1b7201da72494f1e28717ad1a52eb469f95892f957713533de6175e5da190af2",
16,
),
)
.unwrap(),
str_to_fr(
"0x1f8d8822725e36385200c0b201249819a6e6e1e4650808b5bebc6bface7d7636",
16,
),
)
.unwrap(),
str_to_fr(
"0x2c5d82f66c914bafb9701589ba8cfcfb6162b0a12acf88a8d0879a0471b5f85a",
16,
),
)
.unwrap(),
str_to_fr(
"0x14c54148a0940bb820957f5adf3fa1134ef5c4aaa113f4646458f270e0bfbfd0",
16,
),
)
.unwrap(),
str_to_fr(
"0x190d33b12f986f961e10c0ee44d8b9af11be25588cad89d416118e4bf4ebe80c",
16,
),
)
.unwrap(),
];
let mut expected_identity_path_index: Vec<u8> =
@ -379,19 +394,23 @@ mod test {
str_to_fr(
"0x22f98aa9ce704152ac17354914ad73ed1167ae6596af510aa5b3649325e06c92",
16,
),
)
.unwrap(),
str_to_fr(
"0x2a7c7c9b6ce5880b9f6f228d72bf6a575a526f29c66ecceef8b753d38bba7323",
16,
),
)
.unwrap(),
str_to_fr(
"0x2e8186e558698ec1c67af9c14d463ffc470043c9c2988b954d75dd643f36b992",
16,
),
)
.unwrap(),
str_to_fr(
"0x0f57c5571e9a4eab49e2c8cf050dae948aef6ead647392273546249d1c1ff10f",
16,
),
)
.unwrap(),
]);
expected_identity_path_index.append(&mut vec![0, 0, 0, 0]);
}
@ -400,7 +419,8 @@ mod test {
expected_path_elements.append(&mut vec![str_to_fr(
"0x1830ee67b5fb554ad5f63d4388800e1cfe78e310697d46e43c9ce36134f72cca",
16,
)]);
)
.unwrap()]);
expected_identity_path_index.append(&mut vec![0]);
}
@ -439,7 +459,7 @@ mod test {
let proof_values = proof_values_from_witness(&rln_witness);
// We prepare id_commitment and we set the leaf at provided index
let rln_witness_ser = serialize_witness(&rln_witness);
let rln_witness_ser = serialize_witness(&rln_witness).unwrap();
let input_buffer = &Buffer::from(rln_witness_ser.as_ref());
let mut output_buffer = MaybeUninit::<Buffer>::uninit();
let now = Instant::now();
@ -569,7 +589,7 @@ mod test {
let rln_pointer = unsafe { &mut *rln_pointer.assume_init() };
// We add leaves in a batch into the tree
let leaves_ser = vec_fr_to_bytes_le(&leaves);
let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap();
let input_buffer = &Buffer::from(leaves_ser.as_ref());
let success = init_tree_with_leaves(rln_pointer, input_buffer);
assert!(success, "init tree with leaves call failed");
@ -654,7 +674,7 @@ mod test {
let rln_pointer = unsafe { &mut *rln_pointer.assume_init() };
// We add leaves in a batch into the tree
let leaves_ser = vec_fr_to_bytes_le(&leaves);
let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap();
let input_buffer = &Buffer::from(leaves_ser.as_ref());
let success = init_tree_with_leaves(rln_pointer, input_buffer);
assert!(success, "set leaves call failed");
@ -957,16 +977,15 @@ mod test {
assert_eq!(
identity_secret_hash,
expected_identity_secret_hash_seed_bytes
expected_identity_secret_hash_seed_bytes.unwrap()
);
assert_eq!(id_commitment, expected_id_commitment_seed_bytes);
assert_eq!(id_commitment, expected_id_commitment_seed_bytes.unwrap());
}
#[test]
// Tests hash to field using FFI APIs
fn test_seeded_extended_keygen_ffi() {
let tree_height = TEST_TREE_HEIGHT;
// We create a RLN instance
let mut rln_pointer = MaybeUninit::<*mut RLN>::uninit();
let input_buffer = &Buffer::from(TEST_RESOURCES_FOLDER.as_bytes());
@ -1004,13 +1023,19 @@ mod test {
16,
);
assert_eq!(identity_trapdoor, expected_identity_trapdoor_seed_bytes);
assert_eq!(identity_nullifier, expected_identity_nullifier_seed_bytes);
assert_eq!(
identity_trapdoor,
expected_identity_trapdoor_seed_bytes.unwrap()
);
assert_eq!(
identity_nullifier,
expected_identity_nullifier_seed_bytes.unwrap()
);
assert_eq!(
identity_secret_hash,
expected_identity_secret_hash_seed_bytes
expected_identity_secret_hash_seed_bytes.unwrap()
);
assert_eq!(id_commitment, expected_id_commitment_seed_bytes);
assert_eq!(id_commitment, expected_id_commitment_seed_bytes.unwrap());
}
#[test]
@ -1045,7 +1070,7 @@ mod test {
for _ in 0..number_of_inputs {
inputs.push(Fr::rand(&mut rng));
}
let inputs_ser = vec_fr_to_bytes_le(&inputs);
let inputs_ser = vec_fr_to_bytes_le(&inputs).unwrap();
let input_buffer = &Buffer::from(inputs_ser.as_ref());
let expected_hash = utils_poseidon_hash(inputs.as_ref());

View File

@ -184,6 +184,7 @@ mod test {
"0x1984f2e01184aef5cb974640898a5f5c25556554e2b06d99d4841badb8b198cd",
16
)
.unwrap()
);
} else if TEST_TREE_HEIGHT == 19 {
assert_eq!(
@ -192,6 +193,7 @@ mod test {
"0x219ceb53f2b1b7a6cf74e80d50d44d68ecb4a53c6cc65b25593c8d56343fb1fe",
16
)
.unwrap()
);
} else if TEST_TREE_HEIGHT == 20 {
assert_eq!(
@ -200,6 +202,7 @@ mod test {
"0x21947ffd0bce0c385f876e7c97d6a42eec5b1fe935aab2f01c1f8a8cbcc356d2",
16
)
.unwrap()
);
}
@ -213,63 +216,78 @@ mod test {
str_to_fr(
"0x0000000000000000000000000000000000000000000000000000000000000000",
16,
),
)
.unwrap(),
str_to_fr(
"0x2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864",
16,
),
)
.unwrap(),
str_to_fr(
"0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1",
16,
),
)
.unwrap(),
str_to_fr(
"0x18f43331537ee2af2e3d758d50f72106467c6eea50371dd528d57eb2b856d238",
16,
),
)
.unwrap(),
str_to_fr(
"0x07f9d837cb17b0d36320ffe93ba52345f1b728571a568265caac97559dbc952a",
16,
),
)
.unwrap(),
str_to_fr(
"0x2b94cf5e8746b3f5c9631f4c5df32907a699c58c94b2ad4d7b5cec1639183f55",
16,
),
)
.unwrap(),
str_to_fr(
"0x2dee93c5a666459646ea7d22cca9e1bcfed71e6951b953611d11dda32ea09d78",
16,
),
)
.unwrap(),
str_to_fr(
"0x078295e5a22b84e982cf601eb639597b8b0515a88cb5ac7fa8a4aabe3c87349d",
16,
),
)
.unwrap(),
str_to_fr(
"0x2fa5e5f18f6027a6501bec864564472a616b2e274a41211a444cbe3a99f3cc61",
16,
),
)
.unwrap(),
str_to_fr(
"0x0e884376d0d8fd21ecb780389e941f66e45e7acce3e228ab3e2156a614fcd747",
16,
),
)
.unwrap(),
str_to_fr(
"0x1b7201da72494f1e28717ad1a52eb469f95892f957713533de6175e5da190af2",
16,
),
)
.unwrap(),
str_to_fr(
"0x1f8d8822725e36385200c0b201249819a6e6e1e4650808b5bebc6bface7d7636",
16,
),
)
.unwrap(),
str_to_fr(
"0x2c5d82f66c914bafb9701589ba8cfcfb6162b0a12acf88a8d0879a0471b5f85a",
16,
),
)
.unwrap(),
str_to_fr(
"0x14c54148a0940bb820957f5adf3fa1134ef5c4aaa113f4646458f270e0bfbfd0",
16,
),
)
.unwrap(),
str_to_fr(
"0x190d33b12f986f961e10c0ee44d8b9af11be25588cad89d416118e4bf4ebe80c",
16,
),
)
.unwrap(),
];
let mut expected_identity_path_index: Vec<u8> =
@ -281,19 +299,23 @@ mod test {
str_to_fr(
"0x22f98aa9ce704152ac17354914ad73ed1167ae6596af510aa5b3649325e06c92",
16,
),
)
.unwrap(),
str_to_fr(
"0x2a7c7c9b6ce5880b9f6f228d72bf6a575a526f29c66ecceef8b753d38bba7323",
16,
),
)
.unwrap(),
str_to_fr(
"0x2e8186e558698ec1c67af9c14d463ffc470043c9c2988b954d75dd643f36b992",
16,
),
)
.unwrap(),
str_to_fr(
"0x0f57c5571e9a4eab49e2c8cf050dae948aef6ead647392273546249d1c1ff10f",
16,
),
)
.unwrap(),
]);
expected_identity_path_index.append(&mut vec![0, 0, 0, 0]);
}
@ -302,7 +324,8 @@ mod test {
expected_path_elements.append(&mut vec![str_to_fr(
"0x1830ee67b5fb554ad5f63d4388800e1cfe78e310697d46e43c9ce36134f72cca",
16,
)]);
)
.unwrap()]);
expected_identity_path_index.append(&mut vec![0]);
}
@ -319,7 +342,7 @@ mod test {
// We generate all relevant keys
let proving_key = zkey_from_folder(TEST_RESOURCES_FOLDER).unwrap();
let verification_key = vk_from_folder(TEST_RESOURCES_FOLDER).unwrap();
let builder = circom_from_folder(TEST_RESOURCES_FOLDER);
let builder = circom_from_folder(TEST_RESOURCES_FOLDER).unwrap();
// We compute witness from the json input example
let mut witness_json: &str = "";
@ -334,10 +357,12 @@ mod test {
let rln_witness = rln_witness_from_json(witness_json);
// Let's generate a zkSNARK proof
let proof = generate_proof(builder, &proving_key, &rln_witness).unwrap();
let rln_witness_unwrapped = rln_witness.unwrap();
let proof_values = proof_values_from_witness(&rln_witness);
// Let's generate a zkSNARK proof
let proof = generate_proof(builder, &proving_key, &rln_witness_unwrapped).unwrap();
let proof_values = proof_values_from_witness(&rln_witness_unwrapped);
// Let's verify the proof
let verified = verify_proof(&verification_key, &proof, &proof_values);
@ -378,7 +403,7 @@ mod test {
// We generate all relevant keys
let proving_key = zkey_from_folder(TEST_RESOURCES_FOLDER).unwrap();
let verification_key = vk_from_folder(TEST_RESOURCES_FOLDER).unwrap();
let builder = circom_from_folder(TEST_RESOURCES_FOLDER);
let builder = circom_from_folder(TEST_RESOURCES_FOLDER).unwrap();
// Let's generate a zkSNARK proof
let proof = generate_proof(builder, &proving_key, &rln_witness).unwrap();
@ -404,10 +429,10 @@ mod test {
witness_json = WITNESS_JSON_20;
}
let rln_witness = rln_witness_from_json(witness_json);
let rln_witness = rln_witness_from_json(witness_json).unwrap();
let ser = serialize_witness(&rln_witness);
let (deser, _) = deserialize_witness(&ser);
let ser = serialize_witness(&rln_witness).unwrap();
let (deser, _) = deserialize_witness(&ser).unwrap();
assert_eq!(rln_witness, deser);
// We test Proof values serialization
@ -429,11 +454,13 @@ mod test {
let expected_identity_secret_hash_seed_phrase = str_to_fr(
"0x20df38f3f00496f19fe7c6535492543b21798ed7cb91aebe4af8012db884eda3",
16,
);
)
.unwrap();
let expected_id_commitment_seed_phrase = str_to_fr(
"0x1223a78a5d66043a7f9863e14507dc80720a5602b2a894923e5b5147d5a9c325",
16,
);
)
.unwrap();
assert_eq!(
identity_secret_hash,
@ -449,11 +476,13 @@ mod test {
let expected_identity_secret_hash_seed_bytes = str_to_fr(
"0x766ce6c7e7a01bdf5b3f257616f603918c30946fa23480f2859c597817e6716",
16,
);
)
.unwrap();
let expected_id_commitment_seed_bytes = str_to_fr(
"0xbf16d2b5c0d6f9d9d561e05bfca16a81b4b873bb063508fae360d8c74cef51f",
16,
);
)
.unwrap();
assert_eq!(
identity_secret_hash,

View File

@ -16,7 +16,7 @@ mod test {
let leaf_index = 3;
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
let mut rln = RLN::new(tree_height, input_buffer).unwrap();
// generate identity
let identity_secret_hash = hash_to_field(b"test-merkle-proof");
@ -38,6 +38,7 @@ mod test {
"0x1984f2e01184aef5cb974640898a5f5c25556554e2b06d99d4841badb8b198cd",
16
)
.unwrap()
);
} else if TEST_TREE_HEIGHT == 19 {
assert_eq!(
@ -46,6 +47,7 @@ mod test {
"0x219ceb53f2b1b7a6cf74e80d50d44d68ecb4a53c6cc65b25593c8d56343fb1fe",
16
)
.unwrap()
);
} else if TEST_TREE_HEIGHT == 20 {
assert_eq!(
@ -54,6 +56,7 @@ mod test {
"0x21947ffd0bce0c385f876e7c97d6a42eec5b1fe935aab2f01c1f8a8cbcc356d2",
16
)
.unwrap()
);
}
@ -62,71 +65,86 @@ mod test {
rln.get_proof(leaf_index, &mut buffer).unwrap();
let buffer_inner = buffer.into_inner();
let (path_elements, read) = bytes_le_to_vec_fr(&buffer_inner);
let (identity_path_index, _) = bytes_le_to_vec_u8(&buffer_inner[read..].to_vec());
let (path_elements, read) = bytes_le_to_vec_fr(&buffer_inner).unwrap();
let (identity_path_index, _) = bytes_le_to_vec_u8(&buffer_inner[read..].to_vec()).unwrap();
// We check correct computation of the path and indexes
let mut expected_path_elements = vec![
str_to_fr(
"0x0000000000000000000000000000000000000000000000000000000000000000",
16,
),
)
.unwrap(),
str_to_fr(
"0x2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864",
16,
),
)
.unwrap(),
str_to_fr(
"0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1",
16,
),
)
.unwrap(),
str_to_fr(
"0x18f43331537ee2af2e3d758d50f72106467c6eea50371dd528d57eb2b856d238",
16,
),
)
.unwrap(),
str_to_fr(
"0x07f9d837cb17b0d36320ffe93ba52345f1b728571a568265caac97559dbc952a",
16,
),
)
.unwrap(),
str_to_fr(
"0x2b94cf5e8746b3f5c9631f4c5df32907a699c58c94b2ad4d7b5cec1639183f55",
16,
),
)
.unwrap(),
str_to_fr(
"0x2dee93c5a666459646ea7d22cca9e1bcfed71e6951b953611d11dda32ea09d78",
16,
),
)
.unwrap(),
str_to_fr(
"0x078295e5a22b84e982cf601eb639597b8b0515a88cb5ac7fa8a4aabe3c87349d",
16,
),
)
.unwrap(),
str_to_fr(
"0x2fa5e5f18f6027a6501bec864564472a616b2e274a41211a444cbe3a99f3cc61",
16,
),
)
.unwrap(),
str_to_fr(
"0x0e884376d0d8fd21ecb780389e941f66e45e7acce3e228ab3e2156a614fcd747",
16,
),
)
.unwrap(),
str_to_fr(
"0x1b7201da72494f1e28717ad1a52eb469f95892f957713533de6175e5da190af2",
16,
),
)
.unwrap(),
str_to_fr(
"0x1f8d8822725e36385200c0b201249819a6e6e1e4650808b5bebc6bface7d7636",
16,
),
)
.unwrap(),
str_to_fr(
"0x2c5d82f66c914bafb9701589ba8cfcfb6162b0a12acf88a8d0879a0471b5f85a",
16,
),
)
.unwrap(),
str_to_fr(
"0x14c54148a0940bb820957f5adf3fa1134ef5c4aaa113f4646458f270e0bfbfd0",
16,
),
)
.unwrap(),
str_to_fr(
"0x190d33b12f986f961e10c0ee44d8b9af11be25588cad89d416118e4bf4ebe80c",
16,
),
)
.unwrap(),
];
let mut expected_identity_path_index: Vec<u8> =
@ -138,19 +156,23 @@ mod test {
str_to_fr(
"0x22f98aa9ce704152ac17354914ad73ed1167ae6596af510aa5b3649325e06c92",
16,
),
)
.unwrap(),
str_to_fr(
"0x2a7c7c9b6ce5880b9f6f228d72bf6a575a526f29c66ecceef8b753d38bba7323",
16,
),
)
.unwrap(),
str_to_fr(
"0x2e8186e558698ec1c67af9c14d463ffc470043c9c2988b954d75dd643f36b992",
16,
),
)
.unwrap(),
str_to_fr(
"0x0f57c5571e9a4eab49e2c8cf050dae948aef6ead647392273546249d1c1ff10f",
16,
),
)
.unwrap(),
]);
expected_identity_path_index.append(&mut vec![0, 0, 0, 0]);
}
@ -159,7 +181,8 @@ mod test {
expected_path_elements.append(&mut vec![str_to_fr(
"0x1830ee67b5fb554ad5f63d4388800e1cfe78e310697d46e43c9ce36134f72cca",
16,
)]);
)
.unwrap()]);
expected_identity_path_index.append(&mut vec![0]);
}
@ -193,11 +216,13 @@ mod test {
let expected_identity_secret_hash_seed_bytes = str_to_fr(
"0x766ce6c7e7a01bdf5b3f257616f603918c30946fa23480f2859c597817e6716",
16,
);
)
.unwrap();
let expected_id_commitment_seed_bytes = str_to_fr(
"0xbf16d2b5c0d6f9d9d561e05bfca16a81b4b873bb063508fae360d8c74cef51f",
16,
);
)
.unwrap();
assert_eq!(
identity_secret_hash,
@ -226,19 +251,23 @@ mod test {
let expected_identity_trapdoor_seed_bytes = str_to_fr(
"0x766ce6c7e7a01bdf5b3f257616f603918c30946fa23480f2859c597817e6716",
16,
);
)
.unwrap();
let expected_identity_nullifier_seed_bytes = str_to_fr(
"0x1f18714c7bc83b5bca9e89d404cf6f2f585bc4c0f7ed8b53742b7e2b298f50b4",
16,
);
)
.unwrap();
let expected_identity_secret_hash_seed_bytes = str_to_fr(
"0x2aca62aaa7abaf3686fff2caf00f55ab9462dc12db5b5d4bcf3994e671f8e521",
16,
);
)
.unwrap();
let expected_id_commitment_seed_bytes = str_to_fr(
"0x68b66aa0a8320d2e56842581553285393188714c48f9b17acd198b4f1734c5c",
16,
);
)
.unwrap();
assert_eq!(identity_trapdoor, expected_identity_trapdoor_seed_bytes);
assert_eq!(identity_nullifier, expected_identity_nullifier_seed_bytes);
@ -276,7 +305,7 @@ mod test {
}
let expected_hash = utils_poseidon_hash(&inputs);
let mut input_buffer = Cursor::new(vec_fr_to_bytes_le(&inputs));
let mut input_buffer = Cursor::new(vec_fr_to_bytes_le(&inputs).unwrap());
let mut output_buffer = Cursor::new(Vec::<u8>::new());
public_poseidon_hash(&mut input_buffer, &mut output_buffer).unwrap();

View File

@ -12,7 +12,7 @@ use ark_groth16::{
};
use ark_relations::r1cs::SynthesisError;
use ark_std::UniformRand;
use color_eyre::Result;
use color_eyre::{Report, Result};
use ethers_core::types::U256;
use rand::{thread_rng, Rng};
use semaphore::{
@ -89,7 +89,7 @@ pub enum ProofError {
#[error("Error reading circuit key: {0}")]
CircuitKeyError(#[from] std::io::Error),
#[error("Error producing witness: {0}")]
WitnessError(color_eyre::Report),
WitnessError(Report),
#[error("Error producing proof: {0}")]
SynthesisError(#[from] SynthesisError),
#[error("Error converting public input: {0}")]

View File

@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
ark-ff = { version = "0.3.0", default-features = false, features = ["asm"] }
num-bigint = { version = "0.4.3", default-features = false, features = ["rand"] }
color-eyre = "0.6.1"
[dev-dependencies]
ark-bn254 = { version = "0.3.0" }

View File

@ -16,13 +16,14 @@
#![allow(dead_code)]
use std::collections::HashMap;
use std::io;
use std::{
cmp::max,
fmt::Debug,
iter::{once, repeat, successors},
};
use color_eyre::{Report, Result};
/// In the Hasher trait we define the node type, the default leaf
/// and the hash function used to initialize a Merkle Tree implementation
pub trait Hasher {
@ -114,15 +115,12 @@ impl<H: Hasher> OptimalMerkleTree<H> {
}
// Sets a leaf at the specified tree index
pub fn set(&mut self, index: usize, leaf: H::Fr) -> io::Result<()> {
pub fn set(&mut self, index: usize, leaf: H::Fr) -> Result<()> {
if index >= self.capacity() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"index exceeds set size",
));
return Err(Report::msg("index exceeds set size"));
}
self.nodes.insert((self.depth, index), leaf);
self.recalculate_from(index);
self.recalculate_from(index)?;
self.next_index = max(self.next_index, index + 1);
Ok(())
}
@ -132,31 +130,28 @@ impl<H: Hasher> OptimalMerkleTree<H> {
&mut self,
start: usize,
leaves: I,
) -> io::Result<()> {
) -> Result<()> {
let leaves = leaves.into_iter().collect::<Vec<_>>();
// check if the range is valid
if start + leaves.len() > self.capacity() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"provided range exceeds set size",
));
return Err(Report::msg("provided range exceeds set size"));
}
for (i, leaf) in leaves.iter().enumerate() {
self.nodes.insert((self.depth, start + i), *leaf);
self.recalculate_from(start + i);
self.recalculate_from(start + i)?;
}
self.next_index = max(self.next_index, start + leaves.len());
Ok(())
}
// Sets a leaf at the next available index
pub fn update_next(&mut self, leaf: H::Fr) -> io::Result<()> {
pub fn update_next(&mut self, leaf: H::Fr) -> Result<()> {
self.set(self.next_index, leaf)?;
Ok(())
}
// Deletes a leaf at a certain index by setting it to its default value (next_index is not updated)
pub fn delete(&mut self, index: usize) -> io::Result<()> {
pub fn delete(&mut self, index: usize) -> Result<()> {
// We reset the leaf only if we previously set a leaf at that index
if index < self.next_index {
self.set(index, H::default_leaf())?;
@ -165,12 +160,9 @@ impl<H: Hasher> OptimalMerkleTree<H> {
}
// Computes a merkle proof the the leaf at the specified index
pub fn proof(&self, index: usize) -> io::Result<OptimalMerkleProof<H>> {
pub fn proof(&self, index: usize) -> Result<OptimalMerkleProof<H>> {
if index >= self.capacity() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"index exceeds set size",
));
return Err(Report::msg("index exceeds set size"));
}
let mut witness = Vec::<(H::Fr, u8)>::with_capacity(self.depth);
let mut i = index;
@ -184,17 +176,17 @@ impl<H: Hasher> OptimalMerkleTree<H> {
break;
}
}
assert_eq!(i, 0);
Ok(OptimalMerkleProof(witness))
if i != 0 {
Err(Report::msg("i != 0"))
} else {
Ok(OptimalMerkleProof(witness))
}
}
// Verifies a Merkle proof with respect to the input leaf and the tree root
pub fn verify(&self, leaf: &H::Fr, witness: &OptimalMerkleProof<H>) -> io::Result<bool> {
pub fn verify(&self, leaf: &H::Fr, witness: &OptimalMerkleProof<H>) -> Result<bool> {
if witness.length() != self.depth {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"witness length doesn't match tree depth",
));
return Err(Report::msg("witness length doesn't match tree depth"));
}
let expected_root = witness.compute_root_from(leaf);
Ok(expected_root.eq(&self.root()))
@ -219,7 +211,7 @@ impl<H: Hasher> OptimalMerkleTree<H> {
H::hash(&[self.get_node(depth, b), self.get_node(depth, b + 1)])
}
fn recalculate_from(&mut self, index: usize) {
fn recalculate_from(&mut self, index: usize) -> Result<()> {
let mut i = index;
let mut depth = self.depth;
loop {
@ -231,8 +223,13 @@ impl<H: Hasher> OptimalMerkleTree<H> {
break;
}
}
assert_eq!(depth, 0);
assert_eq!(i, 0);
if depth != 0 {
return Err(Report::msg("did not reach the depth"));
}
if i != 0 {
return Err(Report::msg("did not go through all indexes"));
}
Ok(())
}
}
@ -387,7 +384,7 @@ impl<H: Hasher> FullMerkleTree<H> {
}
// Sets a leaf at the specified tree index
pub fn set(&mut self, leaf: usize, hash: H::Fr) -> io::Result<()> {
pub fn set(&mut self, leaf: usize, hash: H::Fr) -> Result<()> {
self.set_range(leaf, once(hash))?;
self.next_index = max(self.next_index, leaf + 1);
Ok(())
@ -395,41 +392,34 @@ impl<H: Hasher> FullMerkleTree<H> {
// Sets tree nodes, starting from start index
// Function proper of FullMerkleTree implementation
fn set_range<I: IntoIterator<Item = H::Fr>>(
&mut self,
start: usize,
hashes: I,
) -> io::Result<()> {
fn set_range<I: IntoIterator<Item = H::Fr>>(&mut self, start: usize, hashes: I) -> Result<()> {
let index = self.capacity() + start - 1;
let mut count = 0;
// first count number of hashes, and check that they fit in the tree
// then insert into the tree
let hashes = hashes.into_iter().collect::<Vec<_>>();
if hashes.len() + start > self.capacity() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"provided hashes do not fit in the tree",
));
return Err(Report::msg("provided hashes do not fit in the tree"));
}
hashes.into_iter().for_each(|hash| {
self.nodes[index + count] = hash;
count += 1;
});
if count != 0 {
self.update_nodes(index, index + (count - 1));
self.update_nodes(index, index + (count - 1))?;
self.next_index = max(self.next_index, start + count);
}
Ok(())
}
// Sets a leaf at the next available index
pub fn update_next(&mut self, leaf: H::Fr) -> io::Result<()> {
pub fn update_next(&mut self, leaf: H::Fr) -> Result<()> {
self.set(self.next_index, leaf)?;
Ok(())
}
// Deletes a leaf at a certain index by setting it to its default value (next_index is not updated)
pub fn delete(&mut self, index: usize) -> io::Result<()> {
pub fn delete(&mut self, index: usize) -> Result<()> {
// We reset the leaf only if we previously set a leaf at that index
if index < self.next_index {
self.set(index, H::default_leaf())?;
@ -438,12 +428,9 @@ impl<H: Hasher> FullMerkleTree<H> {
}
// Computes a merkle proof the the leaf at the specified index
pub fn proof(&self, leaf: usize) -> io::Result<FullMerkleProof<H>> {
pub fn proof(&self, leaf: usize) -> Result<FullMerkleProof<H>> {
if leaf >= self.capacity() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"index exceeds set size",
));
return Err(Report::msg("index exceeds set size"));
}
let mut index = self.capacity() + leaf - 1;
let mut path = Vec::with_capacity(self.depth + 1);
@ -460,7 +447,7 @@ impl<H: Hasher> FullMerkleTree<H> {
}
// Verifies a Merkle proof with respect to the input leaf and the tree root
pub fn verify(&self, hash: &H::Fr, proof: &FullMerkleProof<H>) -> io::Result<bool> {
pub fn verify(&self, hash: &H::Fr, proof: &FullMerkleProof<H>) -> Result<bool> {
Ok(proof.compute_root_from(hash) == self.root())
}
@ -487,15 +474,18 @@ impl<H: Hasher> FullMerkleTree<H> {
(index + 2).next_power_of_two().trailing_zeros() as usize - 1
}
fn update_nodes(&mut self, start: usize, end: usize) {
debug_assert_eq!(self.levels(start), self.levels(end));
fn update_nodes(&mut self, start: usize, end: usize) -> Result<()> {
if self.levels(start) != self.levels(end) {
return Err(Report::msg("self.levels(start) != self.levels(end)"));
}
if let (Some(start), Some(end)) = (self.parent(start), self.parent(end)) {
for parent in start..=end {
let child = self.first_child(parent);
self.nodes[parent] = H::hash(&[self.nodes[child], self.nodes[child + 1]]);
}
self.update_nodes(start, end);
self.update_nodes(start, end)?;
}
Ok(())
}
}