diff --git a/rln/Cargo.toml b/rln/Cargo.toml index f45b096..49ee035 100644 --- a/rln/Cargo.toml +++ b/rln/Cargo.toml @@ -56,3 +56,8 @@ zkp-u256 = { version = "0.2", optional = true } ethers-core = "0.6.3" tiny-keccak = "2.0.2" + +blake2 = "0.8.1" + +# TODO Remove this and use arkworks instead +sapling-crypto = { package = "sapling-crypto_ce", version = "0.1.3", default-features = false } diff --git a/rln/src/lib.rs b/rln/src/lib.rs index 5f210f3..0cfe497 100644 --- a/rln/src/lib.rs +++ b/rln/src/lib.rs @@ -9,6 +9,8 @@ pub mod poseidon_tree; pub mod public; pub mod util; +pub mod poseidon; + #[cfg(test)] mod test { use super::*; diff --git a/rln/src/poseidon.rs b/rln/src/poseidon.rs new file mode 100644 index 0000000..47045da --- /dev/null +++ b/rln/src/poseidon.rs @@ -0,0 +1,233 @@ +// Adapted from https://github.com/kilic/rln/blob/master/src/poseidon.rs +// +use blake2::{Blake2s, Digest}; + +use sapling_crypto::bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr}; +use sapling_crypto::bellman::pairing::Engine; + +// TODO: Using arkworks libs here instead +//use ff::{Field, PrimeField, PrimeFieldRepr}; +//use ark_ec::{PairingEngine as Engine}; + +#[derive(Clone)] +pub struct PoseidonParams { + rf: usize, + rp: usize, + t: usize, + round_constants: Vec, + mds_matrix: Vec, +} + +#[derive(Clone)] +pub struct Poseidon { + params: PoseidonParams, +} + +impl PoseidonParams { + pub fn new( + rf: usize, + rp: usize, + t: usize, + round_constants: Option>, + mds_matrix: Option>, + seed: Option>, + ) -> PoseidonParams { + let seed = match seed { + Some(seed) => seed, + None => b"".to_vec(), + }; + + let _round_constants = match round_constants { + Some(round_constants) => round_constants, + None => PoseidonParams::::generate_constants(b"drlnhdsc", seed.clone(), rf + rp), + }; + assert_eq!(rf + rp, _round_constants.len()); + + let _mds_matrix = match mds_matrix { + Some(mds_matrix) => mds_matrix, + None => PoseidonParams::::generate_mds_matrix(b"drlnhdsm", seed.clone(), t), + }; + PoseidonParams { + rf, + rp, + t, + round_constants: _round_constants, + mds_matrix: _mds_matrix, + } + } + + pub fn width(&self) -> usize { + return self.t; + } + + pub fn partial_round_len(&self) -> usize { + return self.rp; + } + + pub fn full_round_half_len(&self) -> usize { + return self.rf / 2; + } + + pub fn total_rounds(&self) -> usize { + return self.rf + self.rp; + } + + pub fn round_constant(&self, round: usize) -> E::Fr { + return self.round_constants[round]; + } + + pub fn mds_matrix_row(&self, i: usize) -> Vec { + let w = self.width(); + self.mds_matrix[i * w..(i + 1) * w].to_vec() + } + + pub fn mds_matrix(&self) -> Vec { + self.mds_matrix.clone() + } + + pub fn generate_mds_matrix(persona: &[u8; 8], seed: Vec, t: usize) -> Vec { + let v: Vec = PoseidonParams::::generate_constants(persona, seed, t * 2); + let mut matrix: Vec = Vec::with_capacity(t * t); + for i in 0..t { + for j in 0..t { + let mut tmp = v[i]; + tmp.add_assign(&v[t + j]); + let entry = tmp.inverse().unwrap(); + matrix.insert((i * t) + j, entry); + } + } + matrix + } + + pub fn generate_constants(persona: &[u8; 8], seed: Vec, len: usize) -> Vec { + let mut constants: Vec = Vec::new(); + let mut source = seed.clone(); + loop { + let mut hasher = Blake2s::new(); + hasher.input(persona); + hasher.input(source); + source = hasher.result().to_vec(); + let mut candidate_repr = ::Repr::default(); + candidate_repr.read_le(&source[..]).unwrap(); + if let Ok(candidate) = E::Fr::from_repr(candidate_repr) { + constants.push(candidate); + if constants.len() == len { + break; + } + } + } + constants + } +} + +impl Poseidon { + pub fn new(params: PoseidonParams) -> Poseidon { + Poseidon { params } + } + + pub fn hash(&self, inputs: Vec) -> E::Fr { + let mut state = inputs.clone(); + state.resize(self.t(), E::Fr::zero()); + let mut round_counter: usize = 0; + loop { + self.round(&mut state, round_counter); + round_counter += 1; + if round_counter == self.params.total_rounds() { + break; + } + } + state[0] + } + + fn t(&self) -> usize { + self.params.t + } + + fn round(&self, state: &mut Vec, round: usize) { + let a1 = self.params.full_round_half_len(); + let a2 = a1 + self.params.partial_round_len(); + let a3 = self.params.total_rounds(); + if round < a1 { + self.full_round(state, round); + } else if round >= a1 && round < a2 { + self.partial_round(state, round); + } else if round >= a2 && round < a3 { + if round == a3 - 1 { + self.full_round_last(state); + } else { + self.full_round(state, round); + } + } else { + panic!("should not be here") + } + } + + fn full_round(&self, state: &mut Vec, round: usize) { + self.add_round_constants(state, round); + self.apply_quintic_sbox(state, true); + self.mul_mds_matrix(state); + } + + fn full_round_last(&self, state: &mut Vec) { + let last_round = self.params.total_rounds() - 1; + self.add_round_constants(state, last_round); + self.apply_quintic_sbox(state, true); + } + + fn partial_round(&self, state: &mut Vec, round: usize) { + self.add_round_constants(state, round); + self.apply_quintic_sbox(state, false); + self.mul_mds_matrix(state); + } + + fn add_round_constants(&self, state: &mut Vec, round: usize) { + for (_, b) in state.iter_mut().enumerate() { + let c = self.params.round_constants[round]; + b.add_assign(&c); + } + } + + fn apply_quintic_sbox(&self, state: &mut Vec, full: bool) { + for s in state.iter_mut() { + let mut b = s.clone(); + b.square(); + b.square(); + s.mul_assign(&b); + if !full { + break; + } + } + } + + fn mul_mds_matrix(&self, state: &mut Vec) { + let w = self.params.t; + let mut new_state = vec![E::Fr::zero(); w]; + for (i, ns) in new_state.iter_mut().enumerate() { + for (j, s) in state.iter().enumerate() { + let mut tmp = s.clone(); + tmp.mul_assign(&self.params.mds_matrix[i * w + j]); + ns.add_assign(&tmp); + } + } + for (i, ns) in new_state.iter_mut().enumerate() { + state[i].clone_from(ns); + } + } +} + +#[test] +fn test_poseidon_hash() { + use sapling_crypto::bellman::pairing::bn256; + use sapling_crypto::bellman::pairing::bn256::{Bn256, Fr}; + let params = PoseidonParams::::new(8, 55, 3, None, None, None); + let hasher = Poseidon::::new(params); + let input1: Vec = ["0"].iter().map(|e| Fr::from_str(e).unwrap()).collect(); + let r1: Fr = hasher.hash(input1); + let input2: Vec = ["0", "0"] + .iter() + .map(|e| Fr::from_str(e).unwrap()) + .collect(); + let r2: Fr = hasher.hash(input2.to_vec()); + // println!("{:?}", r1); + assert_eq!(r1, r2, "just to see if internal state resets"); +}