From 1c1fbf402e4b1641ba4061d9c7c1964f2f7b7196 Mon Sep 17 00:00:00 2001 From: kilic Date: Sat, 6 Feb 2021 20:21:12 +0300 Subject: [PATCH] add full node and ffis --- src/circuit/bench.rs | 24 ++-- src/circuit/poseidon.rs | 2 +- src/ffi.rs | 101 +++++++++++----- src/merkle.rs | 255 ++++++++++++++++++++++++++-------------- src/poseidon.rs | 98 +++++++-------- src/public.rs | 142 ++++++++++++++++++---- 6 files changed, 415 insertions(+), 207 deletions(-) diff --git a/src/circuit/bench.rs b/src/circuit/bench.rs index 099c8e6..14b87c8 100644 --- a/src/circuit/bench.rs +++ b/src/circuit/bench.rs @@ -66,7 +66,7 @@ where pub fn new(merkle_depth: usize, poseidon_params: Option>) -> RLNTest { RLNTest { - rln: RLN::new(merkle_depth, poseidon_params), + rln: RLN::new(merkle_depth, 0, poseidon_params), merkle_depth, } } @@ -77,7 +77,7 @@ where pub fn valid_inputs(&self) -> RLNInputs { let mut rng = Self::rng(); - let mut hasher = self.rln.hasher(); + let hasher = self.rln.hasher(); // Initialize empty merkle tree let merkle_depth = self.merkle_depth; @@ -95,8 +95,10 @@ where // C.1 get membership witness - let auth_path = membership_tree.witness(id_index); - assert!(membership_tree.check_inclusion(auth_path.clone(), id_index, id_key.clone())); + let auth_path = membership_tree.get_witness(id_index).unwrap(); + assert!(membership_tree + .check_inclusion(auth_path.clone(), id_index) + .unwrap()); // C.2 prepare sss @@ -126,7 +128,7 @@ where share_y: Some(share_y), epoch: Some(epoch), nullifier: Some(nullifier), - root: Some(membership_tree.root()), + root: Some(membership_tree.get_root()), id_key: Some(id_key), auth_path: auth_path.into_iter().map(|w| Some(w)).collect(), }; @@ -173,12 +175,12 @@ where let mut raw_public_inputs: Vec = Vec::new(); inputs.write_public_inputs(&mut raw_public_inputs).unwrap(); - assert!( - self.rln - .verify(proof.as_slice(), raw_public_inputs.as_slice()) - .unwrap(), - true - ); + // assert!( + // self.rln + // .verify(proof.as_slice(), raw_public_inputs.as_slice()) + // .unwrap(), + // true + // ); let mut circuit_parameters: Vec = Vec::new(); self.rln diff --git a/src/circuit/poseidon.rs b/src/circuit/poseidon.rs index c86cbfd..f3e276f 100644 --- a/src/circuit/poseidon.rs +++ b/src/circuit/poseidon.rs @@ -388,7 +388,7 @@ fn test_poseidon_circuit() { .alloc(cs.namespace(|| "hash alloc"), allocated_inputs) .unwrap(); let result = res_allocated.get_value().unwrap(); - let mut poseidon = PoseidonHasher::new(params.clone()); + let poseidon = PoseidonHasher::new(params.clone()); let expected = poseidon.hash(inputs); assert_eq!(result, expected); diff --git a/src/ffi.rs b/src/ffi.rs index 5a22a32..45baab1 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -30,11 +30,12 @@ impl<'a> From<&Buffer> for &'a [u8] { #[no_mangle] pub extern "C" fn new_circuit_from_params( merkle_depth: usize, + index: usize, parameters_buffer: *const Buffer, ctx: *mut *mut RLN, ) -> bool { let buffer = <&[u8]>::from(unsafe { &*parameters_buffer }); - let rln = match RLN::::new_with_raw_params(merkle_depth, buffer, None) { + let rln = match RLN::::new_with_raw_params(merkle_depth, index, buffer, None) { Ok(rln) => rln, Err(_) => return false, }; @@ -42,6 +43,17 @@ pub extern "C" fn new_circuit_from_params( true } +#[no_mangle] +pub extern "C" fn update_next(ctx: *mut RLN, input_buffer: *const Buffer) -> bool { + let rln = unsafe { &mut *ctx }; + let input_data = <&[u8]>::from(unsafe { &*input_buffer }); + match rln.update_next(input_data) { + Ok(proof_data) => proof_data, + Err(_) => return false, + }; + true +} + #[no_mangle] pub extern "C" fn generate_proof( ctx: *const RLN, @@ -63,14 +75,12 @@ pub extern "C" fn generate_proof( #[no_mangle] pub extern "C" fn verify( ctx: *const RLN, - proof_buffer: *const Buffer, - public_inputs_buffer: *const Buffer, + proof_buffer: *mut Buffer, result_ptr: *mut u32, ) -> bool { let rln = unsafe { &*ctx }; let proof_data = <&[u8]>::from(unsafe { &*proof_buffer }); - let public_inputs_data = <&[u8]>::from(unsafe { &*public_inputs_buffer }); - if match rln.verify(proof_data, public_inputs_data) { + if match rln.verify(proof_data) { Ok(verified) => verified, Err(_) => return false, } { @@ -120,9 +130,10 @@ use std::io::{self, Read, Write}; #[cfg(test)] mod tests { - use crate::circuit::bench; - use crate::poseidon::PoseidonParams; + use crate::{circuit::bench, public::RLNSignal}; + use crate::{poseidon::PoseidonParams, public}; use bellman::pairing::bn256::{Bn256, Fr}; + use rand::{Rand, SeedableRng, XorShiftRng}; use super::*; use std::mem::MaybeUninit; @@ -131,6 +142,10 @@ mod tests { 3usize } + fn index() -> usize { + 2usize + } + fn rln_test() -> bench::RLNTest { let merkle_depth = merkle_depth(); let poseidon_params = PoseidonParams::::new(8, 55, 3, None, None, None); @@ -141,11 +156,13 @@ mod tests { fn rln_pointer(circuit_parameters: Vec) -> MaybeUninit<*mut RLN> { // restore this new curcuit with bindings let merkle_depth = merkle_depth(); + let index = index(); let circuit_parameters_buffer = &Buffer::from(circuit_parameters.as_ref()); let mut rln_pointer = MaybeUninit::<*mut RLN>::uninit(); unsafe { new_circuit_from_params( merkle_depth, + index, circuit_parameters_buffer, rln_pointer.as_mut_ptr(), ) @@ -156,55 +173,83 @@ mod tests { #[test] fn test_proof_ffi() { - let rln_test = rln_test(); + let mut rng = XorShiftRng::from_seed([0x3dbe6258, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + // setup new rln instance + let rln_test = rln_test(); let mut circuit_parameters: Vec = Vec::new(); rln_test .export_circuit_parameters(&mut circuit_parameters) .unwrap(); - let rln_pointer = rln_pointer(circuit_parameters); - let rln_pointer = unsafe { &*rln_pointer.assume_init() }; + let rln_pointer = unsafe { &mut *rln_pointer.assume_init() }; + let index = index(); + // generate new key pair + let mut keypair_buffer = MaybeUninit::::uninit(); + let success = unsafe { key_gen(rln_pointer, keypair_buffer.as_mut_ptr()) }; + assert!(success, "key generation failed"); + let keypair_buffer = unsafe { keypair_buffer.assume_init() }; + let mut keypair_data = <&[u8]>::from(&keypair_buffer); + let mut buf = ::Repr::default(); + buf.read_le(&mut keypair_data).unwrap(); + let id_key = Fr::from_repr(buf).unwrap(); + buf.read_le(&mut keypair_data).unwrap(); + let public_key = Fr::from_repr(buf).unwrap(); + + // insert members + for i in 0..index + 1 { + let new_member: Fr; + if i == index { + new_member = public_key; + } else { + new_member = Fr::rand(&mut rng); + } + let mut input_data: Vec = Vec::new(); + new_member.into_repr().write_le(&mut input_data).unwrap(); + let input_buffer = &Buffer::from(input_data.as_ref()); + + let success = update_next(rln_pointer, input_buffer); + assert!(success, "update with new pubkey failed"); + } + + // create signal + let epoch = Fr::rand(&mut rng); + let signal_hash = Fr::rand(&mut rng); + let inputs = RLNSignal:: { + epoch: epoch, + hash: signal_hash, + id_key: id_key, + }; + + // generate proof let mut inputs_data: Vec = Vec::new(); - let inputs = rln_test.valid_inputs(); inputs.write(&mut inputs_data).unwrap(); let inputs_buffer = &Buffer::from(inputs_data.as_ref()); - let mut proof_buffer = MaybeUninit::::uninit(); - let success = unsafe { generate_proof(rln_pointer, inputs_buffer, proof_buffer.as_mut_ptr()) }; assert!(success, "proof generation failed"); + let mut proof_buffer = unsafe { proof_buffer.assume_init() }; - let proof_buffer = unsafe { proof_buffer.assume_init() }; - - let mut public_inputs_data: Vec = Vec::new(); - inputs.write_public_inputs(&mut public_inputs_data).unwrap(); - let public_inputs_buffer = &Buffer::from(public_inputs_data.as_ref()); - + // verify proof let mut result = 0u32; let result_ptr = &mut result as *mut u32; - - let success = - unsafe { verify(rln_pointer, &proof_buffer, public_inputs_buffer, result_ptr) }; - assert!(success, "verification operation failed"); + let success = unsafe { verify(rln_pointer, &mut proof_buffer, result_ptr) }; + assert!(success, "verification failed"); assert_eq!(0, result); } #[test] fn test_hash_ffi() { let rln_test = rln_test(); - let mut circuit_parameters: Vec = Vec::new(); rln_test .export_circuit_parameters(&mut circuit_parameters) .unwrap(); - let mut hasher = rln_test.hasher(); - + let hasher = rln_test.hasher(); let rln_pointer = rln_pointer(circuit_parameters); let rln_pointer = unsafe { &*rln_pointer.assume_init() }; - let mut input_data: Vec = Vec::new(); let inputs: Vec = ["1", "2"] @@ -248,7 +293,7 @@ mod tests { rln_test .export_circuit_parameters(&mut circuit_parameters) .unwrap(); - let mut hasher = rln_test.hasher(); + let hasher = rln_test.hasher(); let rln_pointer = rln_pointer(circuit_parameters); let rln_pointer = unsafe { &*rln_pointer.assume_init() }; diff --git a/src/merkle.rs b/src/merkle.rs index 9964511..6b04ef3 100644 --- a/src/merkle.rs +++ b/src/merkle.rs @@ -1,15 +1,111 @@ use crate::poseidon::{Poseidon as Hasher, PoseidonParams}; use sapling_crypto::bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr}; use sapling_crypto::bellman::pairing::Engine; -use std::collections::HashMap; +use std::io::{self, Error, ErrorKind}; +use std::{collections::HashMap, hash::Hash}; + +enum SyncMode { + Bootstarp, + Maintain, +} + +pub struct IncrementalMerkleTree +where + E: Engine, +{ + pub self_index: usize, + pub current_index: usize, + merkle_tree: MerkleTree, +} + +impl IncrementalMerkleTree +where + E: Engine, +{ + pub fn empty(hasher: Hasher, depth: usize, self_index: usize) -> Self { + let mut zero: Vec = Vec::with_capacity(depth + 1); + zero.push(E::Fr::from_str("0").unwrap()); + for i in 0..depth { + zero.push(hasher.hash([zero[i]; 2].to_vec())); + } + zero.reverse(); + let merkle_tree = MerkleTree { + hasher: hasher, + zero: zero.clone(), + depth: depth, + nodes: HashMap::new(), + }; + let current_index: usize = 0; + IncrementalMerkleTree { + self_index, + current_index, + merkle_tree, + } + } + + pub fn update_next(&mut self, leaf: E::Fr) { + // println!("{}", self.get_root()); + self.merkle_tree.update(self.current_index, leaf); + self.current_index += 1; + // println!("{}", self.get_root()); + } + + pub fn delete(&mut self, index: usize) -> io::Result<()> { + if index >= self.current_index { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "index exceeds incremental index", + )); + } + let zero = E::Fr::from_str("0").unwrap(); + self.merkle_tree.update(index, zero); + Ok(()) + } + + pub fn get_witness(&self, index: usize) -> io::Result> { + if index >= self.current_index { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "index exceeds incremental index", + )); + } + self.merkle_tree.get_witness(index) + } + + pub fn get_auth_path(&self) -> Vec<(E::Fr, bool)> { + self.merkle_tree.get_witness(self.self_index).unwrap() + } + + pub fn hash(&self, inputs: Vec) -> E::Fr { + self.merkle_tree.hasher.hash(inputs) + } + + pub fn check_inclusion( + &self, + witness: Vec<(E::Fr, bool)>, + leaf_index: usize, + ) -> io::Result { + if leaf_index >= self.current_index { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "index exceeds incremental index", + )); + } + self.merkle_tree.check_inclusion(witness, leaf_index) + } + + pub fn get_root(&self) -> E::Fr { + return self.merkle_tree.get_root(); + } +} pub struct MerkleTree where E: Engine, { pub hasher: Hasher, + pub depth: usize, zero: Vec, - depth: usize, nodes: HashMap<(usize, usize), E::Fr>, } @@ -17,7 +113,7 @@ impl MerkleTree where E: Engine, { - pub fn empty(mut hasher: Hasher, depth: usize) -> Self { + pub fn empty(hasher: Hasher, depth: usize) -> Self { let mut zero: Vec = Vec::with_capacity(depth + 1); zero.push(E::Fr::from_str("0").unwrap()); for i in 0..depth { @@ -32,11 +128,71 @@ where } } + pub fn set_size(&self) -> usize { + 1 << self.depth + } + + pub fn update(&mut self, index: usize, leaf: E::Fr) { + self.nodes.insert((self.depth, index), leaf); + self.recalculate_from(index); + } + + pub fn check_inclusion(&self, witness: Vec<(E::Fr, bool)>, index: usize) -> io::Result { + if index >= self.set_size() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "index exceeds set size", + )); + } + let mut acc = self.get_node(self.depth, index); + + for w in witness.into_iter() { + if w.1 { + acc = self.hasher.hash(vec![acc, w.0]); + } else { + acc = self.hasher.hash(vec![w.0, acc]); + } + } + Ok(acc.eq(&self.get_root())) + } + + pub fn get_root(&self) -> E::Fr { + return self.get_node(0, 0); + } + + pub fn get_witness(&self, index: usize) -> io::Result> { + if index >= self.set_size() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "index exceeds set size", + )); + } + let mut witness = Vec::<(E::Fr, bool)>::with_capacity(self.depth); + let mut i = index; + let mut depth = self.depth; + loop { + i ^= 1; + witness.push((self.get_node(depth, i), (i & 1 == 1))); + i >>= 1; + depth -= 1; + if depth == 0 { + break; + } + } + assert_eq!(i, 0); + Ok(witness) + } + fn get_node(&self, depth: usize, index: usize) -> E::Fr { - *self + let node = *self .nodes .get(&(depth, index)) - .unwrap_or_else(|| &self.zero[depth]) + .unwrap_or_else(|| &self.zero[depth]); + node + } + + fn get_leaf(&self, index: usize) -> E::Fr { + self.get_node(self.depth, index) } fn hash_couple(&mut self, depth: usize, index: usize) -> E::Fr { @@ -45,8 +201,8 @@ where .hash([self.get_node(depth, b), self.get_node(depth, b + 1)].to_vec()) } - fn recalculate_from(&mut self, leaf_index: usize) { - let mut i = leaf_index; + fn recalculate_from(&mut self, index: usize) { + let mut i = index; let mut depth = self.depth; loop { let h = self.hash_couple(depth, i); @@ -60,95 +216,20 @@ where assert_eq!(depth, 0); assert_eq!(i, 0); } - - pub fn insert(&mut self, leaf_index: usize, new: E::Fr, old: Option) { - let d = self.depth; - { - if old.is_some() { - let old = old.unwrap(); - let t = self.get_node(d, leaf_index); - if t.is_zero() { - assert!(old.is_zero()); - } else { - assert!(t.eq(&self.hasher.hash(vec![old]))); - } - } - }; - let leaf = self.hasher.hash(vec![new]); - self.update(leaf_index, leaf); - } - - pub fn update(&mut self, leaf_index: usize, leaf: E::Fr) { - self.nodes.insert((self.depth, leaf_index), leaf); - self.recalculate_from(leaf_index); - } - - pub fn root(&self) -> E::Fr { - return self.get_node(0, 0); - } - - pub fn witness(&mut self, leaf_index: usize) -> Vec<(E::Fr, bool)> { - let mut witness = Vec::<(E::Fr, bool)>::with_capacity(self.depth); - let mut i = leaf_index; - let mut depth = self.depth; - loop { - i ^= 1; - witness.push((self.get_node(depth, i), (i & 1 == 1))); - i >>= 1; - depth -= 1; - if depth == 0 { - break; - } - } - assert_eq!(i, 0); - witness - } - - pub fn check_inclusion( - &mut self, - witness: Vec<(E::Fr, bool)>, - leaf_index: usize, - data: E::Fr, - ) -> bool { - let mut acc = self.hasher.hash(vec![data]); - { - assert!(self.get_node(self.depth, leaf_index).eq(&acc)); - } - for w in witness.into_iter() { - if w.1 { - acc = self.hasher.hash(vec![acc, w.0]); - } else { - acc = self.hasher.hash(vec![w.0, acc]); - } - } - acc.eq(&self.root()) - } } #[test] fn test_merkle_set() { - let zero = Some(Fr::zero()); let data: Vec = (0..8) .map(|s| Fr::from_str(&format!("{}", s)).unwrap()) .collect(); use sapling_crypto::bellman::pairing::bn256::{Bn256, Fr, FrRepr}; let params = PoseidonParams::::new(8, 55, 3, None, None, None); let hasher = Hasher::new(params); - let mut set = MerkleTree::empty(hasher, 3); + let mut set = MerkleTree::empty(hasher.clone(), 3); let leaf_index = 6; - set.insert(leaf_index, data[0], zero); - let witness = set.witness(leaf_index); - assert!(set.check_inclusion(witness, leaf_index, data[0])); -} - -#[test] -fn test_merkle_zeros() { - use sapling_crypto::bellman::pairing::bn256::{Bn256, Fr, FrRepr}; - let params = PoseidonParams::::new(8, 55, 3, None, None, None); - let hasher = Hasher::new(params); - let mut set = MerkleTree::empty(hasher, 32); - set.insert(5, Fr::from_str("1").unwrap(), Some(Fr::zero())); - println!("{}", set.root()); - set.insert(6, Fr::from_str("2").unwrap(), Some(Fr::zero())); - println!("{}", set.root()); + let leaf = hasher.hash(vec![data[0]]); + set.update(leaf_index, leaf); + let witness = set.get_witness(leaf_index).unwrap(); + assert!(set.check_inclusion(witness, leaf_index).unwrap()); } diff --git a/src/poseidon.rs b/src/poseidon.rs index dadd1ee..028fb91 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -14,8 +14,6 @@ pub struct PoseidonParams { #[derive(Clone)] pub struct Poseidon { - state: Vec, - round: usize, params: PoseidonParams, } @@ -118,91 +116,73 @@ impl PoseidonParams { impl Poseidon { pub fn new(params: PoseidonParams) -> Poseidon { - Poseidon { - round: 0, - state: Vec::new(), - params, + Poseidon { params } + } + + pub fn hash(&self, inputs: Vec) -> E::Fr { + let mut state = inputs.clone(); + state.resize(self.t(), E::Fr::zero()); + let mut round_counter: usize = 0; + loop { + self.round(&mut state, round_counter); + round_counter += 1; + if round_counter == self.params.total_rounds() { + break; + } } - } - - fn new_state(&mut self, inputs: Vec) { - let t = self.t(); - self.state = inputs.clone(); - self.state.resize(t, E::Fr::zero()); - } - - fn clear(&mut self) { - self.round = 0; + state[0] } fn t(&self) -> usize { self.params.t } - fn result(&self) -> E::Fr { - self.state[0] - } - - pub fn hash(&mut self, inputs: Vec) -> E::Fr { - self.new_state(inputs); - loop { - self.round(self.round); - self.round += 1; - if self.round == self.params.total_rounds() { - break; - } - } - let r = self.result(); - self.clear(); - r - } - - fn round(&mut self, round: usize) { + fn round(&self, state: &mut Vec, round: usize) { let a1 = self.params.full_round_half_len(); let a2 = a1 + self.params.partial_round_len(); let a3 = self.params.total_rounds(); if round < a1 { - self.full_round(round); + self.full_round(state, round); } else if round >= a1 && round < a2 { - self.partial_round(round); + self.partial_round(state, round); } else if round >= a2 && round < a3 { if round == a3 - 1 { - self.full_round_last(); + self.full_round_last(state); } else { - self.full_round(round); + self.full_round(state, round); } } else { panic!("should not be here") } } - fn full_round(&mut self, round: usize) { - self.add_round_constants(round); - self.apply_quintic_sbox(true); - self.mul_mds_matrix(); + fn full_round(&self, state: &mut Vec, round: usize) { + self.add_round_constants(state, round); + self.apply_quintic_sbox(state, true); + self.mul_mds_matrix(state); } - fn full_round_last(&mut self) { + fn full_round_last(&self, state: &mut Vec) { let last_round = self.params.total_rounds() - 1; - self.add_round_constants(last_round); - self.apply_quintic_sbox(true); + self.add_round_constants(state, last_round); + self.apply_quintic_sbox(state, true); } - fn partial_round(&mut self, round: usize) { - self.add_round_constants(round); - self.apply_quintic_sbox(false); - self.mul_mds_matrix(); + fn partial_round(&self, state: &mut Vec, round: usize) { + self.add_round_constants(state, round); + self.apply_quintic_sbox(state, false); + self.mul_mds_matrix(state); } - fn add_round_constants(&mut self, round: usize) { - for (_, b) in self.state.iter_mut().enumerate() { + fn add_round_constants(&self, state: &mut Vec, round: usize) { + for (_, b) in state.iter_mut().enumerate() { let c = self.params.round_constants[round]; b.add_assign(&c); } } - fn apply_quintic_sbox(&mut self, full: bool) { - for s in self.state.iter_mut() { + fn apply_quintic_sbox(&self, state: &mut Vec, full: bool) { + for s in state.iter_mut() { let mut b = s.clone(); b.square(); b.square(); @@ -213,17 +193,19 @@ impl Poseidon { } } - fn mul_mds_matrix(&mut self) { + fn mul_mds_matrix(&self, state: &mut Vec) { let w = self.params.t; let mut new_state = vec![E::Fr::zero(); w]; for (i, ns) in new_state.iter_mut().enumerate() { - for (j, s) in self.state.iter().enumerate() { + for (j, s) in state.iter().enumerate() { let mut tmp = s.clone(); tmp.mul_assign(&self.params.mds_matrix[i * w + j]); ns.add_assign(&tmp); } } - self.state = new_state; + for (i, ns) in new_state.iter_mut().enumerate() { + state[i].clone_from(ns); + } } } @@ -232,7 +214,7 @@ fn test_poseidon_hash() { use sapling_crypto::bellman::pairing::bn256; use sapling_crypto::bellman::pairing::bn256::{Bn256, Fr}; let params = PoseidonParams::::new(8, 55, 3, None, None, None); - let mut hasher = Poseidon::::new(params); + let hasher = Poseidon::::new(params); let input1: Vec = ["0"].iter().map(|e| Fr::from_str(e).unwrap()).collect(); let r1: Fr = hasher.hash(input1.to_vec()); let input2: Vec = ["0", "0"] diff --git a/src/public.rs b/src/public.rs index c7232f7..21c8742 100644 --- a/src/public.rs +++ b/src/public.rs @@ -1,8 +1,8 @@ -use crate::circuit::poseidon::PoseidonCircuit; use crate::circuit::rln::{RLNCircuit, RLNInputs}; use crate::merkle::MerkleTree; use crate::poseidon::{Poseidon as PoseidonHasher, PoseidonParams}; use crate::utils::{read_inputs, read_uncompressed_proof, write_uncompressed_proof}; +use crate::{circuit::poseidon::PoseidonCircuit, merkle::IncrementalMerkleTree}; use bellman::groth16::generate_random_parameters; use bellman::groth16::{create_proof, prepare_verifying_key, verify_proof}; use bellman::groth16::{create_random_proof, Parameters, Proof}; @@ -10,7 +10,53 @@ use bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr}; use bellman::pairing::{CurveAffine, EncodedPoint, Engine}; use bellman::{Circuit, ConstraintSystem, SynthesisError}; use rand::{Rand, SeedableRng, XorShiftRng}; -use std::io::{self, Error, ErrorKind, Read, Write}; +use std::{ + io::{self, Error, ErrorKind, Read, Write}, + ptr::null, +}; +// Rate Limit Nullifier + +#[derive(Clone)] +pub struct RLNSignal +where + E: Engine, +{ + pub epoch: E::Fr, + pub hash: E::Fr, + pub id_key: E::Fr, +} + +impl RLNSignal +where + E: Engine, +{ + pub fn read(mut reader: R) -> io::Result> { + let mut buf = ::Repr::default(); + + buf.read_le(&mut reader)?; + let hash = + E::Fr::from_repr(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + buf.read_le(&mut reader)?; + let epoch = + E::Fr::from_repr(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + buf.read_le(&mut reader)?; + let id_key = + E::Fr::from_repr(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + Ok(RLNSignal { + epoch, + hash, + id_key, + }) + } + + pub fn write(&self, mut writer: W) -> io::Result<()> { + self.epoch.into_repr().write_le(&mut writer).unwrap(); + self.hash.into_repr().write_le(&mut writer).unwrap(); + self.id_key.into_repr().write_le(&mut writer).unwrap(); + Ok(()) + } +} pub struct RLN where @@ -18,7 +64,7 @@ where { circuit_parameters: Parameters, poseidon_params: PoseidonParams, - merkle_depth: usize, + tree: IncrementalMerkleTree, } impl RLN @@ -41,31 +87,35 @@ where fn new_with_params( merkle_depth: usize, + index: usize, circuit_parameters: Parameters, poseidon_params: PoseidonParams, ) -> RLN { + let hasher = PoseidonHasher::new(poseidon_params.clone()); + let tree = IncrementalMerkleTree::empty(hasher, merkle_depth, index); RLN { circuit_parameters, poseidon_params, - merkle_depth, + tree, } } - pub fn poseidon_params(&self) -> PoseidonParams { - self.poseidon_params.clone() - } - - pub fn new(merkle_depth: usize, poseidon_params: Option>) -> RLN { + pub fn new( + merkle_depth: usize, + index: usize, + poseidon_params: Option>, + ) -> RLN { let poseidon_params = match poseidon_params { Some(params) => params, None => Self::default_poseidon_params(), }; let circuit_parameters = Self::new_circuit(merkle_depth, poseidon_params.clone()); - Self::new_with_params(merkle_depth, circuit_parameters, poseidon_params) + Self::new_with_params(merkle_depth, index, circuit_parameters, poseidon_params) } pub fn new_with_raw_params( merkle_depth: usize, + index: usize, raw_circuit_parameters: R, poseidon_params: Option>, ) -> io::Result> { @@ -76,17 +126,31 @@ where }; Ok(Self::new_with_params( merkle_depth, + index, circuit_parameters, poseidon_params, )) } + pub fn update_next(&mut self, input: R) -> io::Result<()> { + let mut buf = ::Repr::default(); + buf.read_le(input)?; + let leaf = + E::Fr::from_repr(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + self.tree.update_next(leaf); + Ok(()) + } + + pub fn poseidon_params(&self) -> PoseidonParams { + self.poseidon_params.clone() + } + pub fn hasher(&self) -> PoseidonHasher { PoseidonHasher::new(self.poseidon_params.clone()) } pub fn hash(&self, input: R, n: usize, mut output: W) -> io::Result<()> { - let mut hasher = self.hasher(); + let hasher = self.hasher(); let input: Vec = read_inputs::(input, n)?; let result = hasher.hash(input); // let mut output_data: Vec = Vec::new(); @@ -94,27 +158,61 @@ where Ok(()) } + /// input: |epoch<32>|signal_hash<32>|id_key<32>| + /// output: |proof|root<32>|epoch<32>|share_x<32>|share_y<32>|nullifier<32>| pub fn generate_proof(&self, input: R, mut output: W) -> io::Result<()> { use rand::chacha::ChaChaRng; use rand::SeedableRng; let mut rng = ChaChaRng::new_unseeded(); - let inputs = RLNInputs::::read(input)?; - assert_eq!(self.merkle_depth, inputs.merkle_depth()); - let circuit_hasher = PoseidonCircuit::new(self.poseidon_params.clone()); + let signal = RLNSignal::::read(input)?; + // prepare inputs + + let hasher = self.hasher(); + let share_x = signal.hash.clone(); + + // line equation + let a_0 = signal.id_key.clone(); + let a_1: E::Fr = hasher.hash(vec![a_0, signal.epoch]); + // evaluate line equation + let mut share_y = a_1.clone(); + share_y.mul_assign(&share_x); + share_y.add_assign(&a_0); + let nullifier = hasher.hash(vec![a_1]); + + let root = self.tree.get_root(); + let auth_path = self.tree.get_auth_path(); + + let inputs = RLNInputs:: { + share_x: Some(share_x), + share_y: Some(share_y), + epoch: Some(signal.epoch), + nullifier: Some(nullifier), + root: Some(root), + id_key: Some(signal.id_key), + auth_path: auth_path.into_iter().map(|w| Some(w)).collect(), + }; + let circuit = RLNCircuit { inputs: inputs.clone(), - hasher: circuit_hasher.clone(), + hasher: PoseidonCircuit::new(self.poseidon_params.clone()), }; + let proof = create_random_proof(circuit, &self.circuit_parameters, &mut rng).unwrap(); - write_uncompressed_proof(proof, &mut output)?; - // proof.write(&mut w).unwrap(); + write_uncompressed_proof(proof.clone(), &mut output)?; + root.into_repr().write_le(&mut output)?; + signal.epoch.into_repr().write_le(&mut output)?; + share_x.into_repr().write_le(&mut output)?; + share_y.into_repr().write_le(&mut output)?; + nullifier.into_repr().write_le(&mut output)?; + Ok(()) } - pub fn verify(&self, uncompresed_proof: R, raw_public_inputs: R) -> io::Result { - let proof = read_uncompressed_proof(uncompresed_proof)?; - // let proof = Proof::read(uncompresed_proof).unwrap(); - let public_inputs = RLNInputs::::read_public_inputs(raw_public_inputs)?; + /// proof: |proof|root<32>|epoch<32>|share_x<32>|share_y<32>|nullifier<32>| + pub fn verify(&self, mut proof_data: R) -> io::Result { + let proof = read_uncompressed_proof(&mut proof_data)?; + let public_inputs = RLNInputs::::read_public_inputs(&mut proof_data)?; + // TODO: root must be checked here let verifing_key = prepare_verifying_key(&self.circuit_parameters.vk); let success = verify_proof(&verifing_key, &proof, &public_inputs).unwrap(); Ok(success) @@ -122,7 +220,7 @@ where pub fn key_gen(&self, mut w: W) -> io::Result<()> { let mut rng = XorShiftRng::from_seed([0x3dbe6258, 0x8d313d76, 0x3237db17, 0xe5bc0654]); - let mut hasher = self.hasher(); + let hasher = self.hasher(); let secret = E::Fr::rand(&mut rng); let public: E::Fr = hasher.hash(vec![secret.clone()]); secret.into_repr().write_le(&mut w)?;