From 826a02f652dc211963d6eac4a08dbe9ce9094afa Mon Sep 17 00:00:00 2001
From: Oskar Thoren
Date: Wed, 16 Mar 2022 15:54:42 +0800
Subject: [PATCH] feat(rln): add hash, merkle and poseidon tree

Import from existing semaphore-rs
---
 rln/Cargo.toml           |   5 +
 rln/src/hash.rs          | 225 ++++++++++++++++++++++++
 rln/src/lib.rs           |   3 +
 rln/src/merkle_tree.rs   | 364 +++++++++++++++++++++++++++++++++++++++
 rln/src/poseidon_tree.rs |  87 ++++++++++
 5 files changed, 684 insertions(+)
 create mode 100644 rln/src/hash.rs
 create mode 100644 rln/src/merkle_tree.rs
 create mode 100644 rln/src/poseidon_tree.rs

diff --git a/rln/Cargo.toml b/rln/Cargo.toml
index 00a7a0b..f45b096 100644
--- a/rln/Cargo.toml
+++ b/rln/Cargo.toml
@@ -51,3 +51,8 @@ poseidon-rs = "0.0.8"
 sha2 = "0.10.1"
 
 ff = { package="ff_ce", version="0.11"}
+
+zkp-u256 = { version = "0.2", optional = true }
+ethers-core = "0.6.3"
+
+tiny-keccak = "2.0.2"
diff --git a/rln/src/hash.rs b/rln/src/hash.rs
new file mode 100644
index 0000000..4428f76
--- /dev/null
+++ b/rln/src/hash.rs
@@ -0,0 +1,225 @@
+// Adapted from https://github.com/worldcoin/semaphore-rs/blob/main/src/hash.rs
+//
+use ethers_core::types::U256;
+use num_bigint::{BigInt, Sign};
+use serde::{
+    de::{Error as DeError, Visitor},
+    ser::Error as _,
+    Deserialize, Serialize,
+};
+use std::{
+    fmt::{Debug, Display, Formatter, Result as FmtResult},
+    str::{from_utf8, FromStr},
+};
+
+/// Container for 256-bit hash values.
+#[derive(Clone, Copy, PartialEq, Eq, Default)]
+pub struct Hash(pub [u8; 32]);
+
+impl Hash {
+    /// Wraps 32 raw bytes; numeric conversions in this module read them as big-endian.
+    #[must_use]
+    pub const fn from_bytes_be(bytes: [u8; 32]) -> Self {
+        Self(bytes)
+    }
+
+    /// Borrows the underlying 32 bytes.
+    #[must_use]
+    pub const fn as_bytes_be(&self) -> &[u8; 32] {
+        &self.0
+    }
+}
+
+/// Debug print hashes using `hex!(..)` literals.
+impl Debug for Hash {
+    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
+        write!(f, "Hash(hex!(\"{}\"))", hex::encode(&self.0))
+    }
+}
+
+/// Display print hashes as `0x...`.
+impl Display for Hash {
+    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
+        write!(f, "0x{}", hex::encode(&self.0))
+    }
+}
+
+/// Conversion to ethers `U256`
+impl From<&Hash> for U256 {
+    fn from(hash: &Hash) -> Self {
+        Self::from_big_endian(hash.as_bytes_be())
+    }
+}
+
+/// Conversion from ethers `U256`
+impl From<U256> for Hash {
+    fn from(u256: U256) -> Self {
+        let mut bytes = [0_u8; 32];
+        u256.to_big_endian(&mut bytes);
+        Self::from_bytes_be(bytes)
+    }
+}
+
+/// Conversion from vec. Panics if shorter than 32 bytes; bytes past the first 32 are ignored.
+impl From<Vec<u8>> for Hash {
+    fn from(vec: Vec<u8>) -> Self {
+        let mut bytes = [0_u8; 32];
+        bytes.copy_from_slice(&vec[0..32]);
+        Self::from_bytes_be(bytes)
+    }
+}
+
+/// Conversion to BigInt
+impl From<Hash> for BigInt {
+    fn from(hash: Hash) -> Self {
+        Self::from_bytes_be(Sign::Plus, hash.as_bytes_be())
+    }
+}
+
+/// Conversion to BigInt (from a reference)
+impl From<&Hash> for BigInt {
+    fn from(hash: &Hash) -> Self {
+        Self::from_bytes_be(Sign::Plus, hash.as_bytes_be())
+    }
+}
+
+/// Parse Hash from hex string.
+/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix
+/// but they must always be exactly 32 bytes.
+impl FromStr for Hash {
+    type Err = hex::FromHexError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let str = trim_hex_prefix(s);
+        let mut out = [0_u8; 32];
+        hex::decode_to_slice(str, &mut out)?;
+        Ok(Self(out))
+    }
+}
+
+/// Serialize hashes into human readable hex strings or byte arrays.
+/// Hex strings are lower case without prefix and always 32 bytes.
+impl Serialize for Hash {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        if serializer.is_human_readable() {
+            let mut hex_ascii = [0_u8; 64];
+            hex::encode_to_slice(self.0, &mut hex_ascii)
+                .map_err(|e| S::Error::custom(format!("Error hex encoding: {}", e)))?;
+            from_utf8(&hex_ascii)
+                .map_err(|e| S::Error::custom(format!("Invalid hex encoding: {}", e)))?
+                .serialize(serializer)
+        } else {
+            self.0.serialize(serializer)
+        }
+    }
+}
+
+/// Deserialize human readable hex strings or byte arrays into hashes.
+/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix
+/// but they must always be exactly 32 bytes.
+impl<'de> Deserialize<'de> for Hash {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        if deserializer.is_human_readable() {
+            deserializer.deserialize_str(HashStrVisitor)
+        } else {
+            <[u8; 32]>::deserialize(deserializer).map(Hash)
+        }
+    }
+}
+
+/// Serde visitor that parses a hex string into a [`Hash`].
+struct HashStrVisitor;
+
+impl<'de> Visitor<'de> for HashStrVisitor {
+    type Value = Hash;
+
+    fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
+        formatter.write_str("a 32 byte hex string")
+    }
+
+    fn visit_borrowed_str<E>(self, value: &'de str) -> Result<Self::Value, E>
+    where
+        E: DeError,
+    {
+        Hash::from_str(value).map_err(|e| E::custom(format!("Error in hex: {}", e)))
+    }
+
+    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+    where
+        E: DeError,
+    {
+        Hash::from_str(value).map_err(|e| E::custom(format!("Error in hex: {}", e)))
+    }
+
+    fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
+    where
+        E: DeError,
+    {
+        Hash::from_str(&value).map_err(|e| E::custom(format!("Error in hex: {}", e)))
+    }
+}
+
+/// Helper function to optionally remove `0x` prefix from hex strings.
+fn trim_hex_prefix(str: &str) -> &str {
+    // `strip_prefix` respects char boundaries; slicing `str[..2]` by byte
+    // offset panics when a multi-byte character straddles offset 2.
+    str.strip_prefix("0x")
+        .or_else(|| str.strip_prefix("0X"))
+        .unwrap_or(str)
+}
+
+#[cfg(test)]
+pub mod test {
+    use super::*;
+    use hex_literal::hex;
+    use serde_json::{from_str, to_string};
+
+    #[test]
+    fn test_serialize() {
+        let hash = Hash([0; 32]);
+        assert_eq!(
+            to_string(&hash).unwrap(),
+            "\"0000000000000000000000000000000000000000000000000000000000000000\""
+        );
+        let hash = Hash(hex!(
+            "1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe"
+        ));
+        assert_eq!(
+            to_string(&hash).unwrap(),
+            "\"1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe\""
+        );
+    }
+
+    #[test]
+    fn test_deserialize() {
+        assert_eq!(
+            from_str::<Hash>(
+                "\"0x1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe\""
+            )
+            .unwrap(),
+            Hash(hex!(
+                "1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe"
+            ))
+        );
+        assert_eq!(
+            from_str::<Hash>(
+                "\"0X1C4823575d154474EE3e5ac838d002456a815181437afd14f126da58a9912bbe\""
+            )
+            .unwrap(),
+            Hash(hex!(
+                "1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe"
+            ))
+        );
+    }
+
+    #[test]
+    fn test_non_ascii_input_is_rejected() {
+        assert!("a\u{20ac}".parse::<Hash>().is_err());
+    }
+}
diff --git a/rln/src/lib.rs b/rln/src/lib.rs
index 79fb4e4..7fe6670 100644
--- a/rln/src/lib.rs
+++ b/rln/src/lib.rs
@@ -3,5 +3,8 @@
 
 pub mod ffi;
 pub mod identity;
+pub mod hash;
+pub mod merkle_tree;
+pub mod poseidon_tree;
 pub mod public;
 pub mod util;
diff --git a/rln/src/merkle_tree.rs b/rln/src/merkle_tree.rs
new file mode 100644
index 0000000..9d4bcb1
--- /dev/null
+++ b/rln/src/merkle_tree.rs
@@ -0,0 +1,364 @@
+// Adapted from https://github.com/worldcoin/semaphore-rs/blob/main/src/merkle_tree.rs
+//
+//! Implements basic binary Merkle trees
+//!
+//! # To do
+//!
+//! * Disk based storage backend (using mmapped files should be easy)
+
+use num_bigint::BigInt;
+use serde::{Deserialize, Serialize};
+use std::{
+    fmt::Debug,
+    iter::{once, repeat, successors},
+};
+
+/// Hash types, values and algorithms for a Merkle tree
+pub trait Hasher {
+    /// Type of the leaf and node hashes
+    type Hash: Clone + Eq + Serialize;
+
+    /// Compute the hash of an intermediate node
+    fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash;
+}
+
+/// Merkle tree with all leaf and intermediate hashes stored
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct MerkleTree<H: Hasher> {
+    /// Depth of the tree, # of layers including leaf layer
+    depth: usize,
+
+    /// Hash value of empty subtrees of given depth, starting at leaf level
+    empty: Vec<H::Hash>,
+
+    /// Hash values of tree nodes and leaves, breadth first order
+    nodes: Vec<H::Hash>,
+}
+
+/// Element of a Merkle proof
+#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub enum Branch<H: Hasher> {
+    /// Left branch taken, value is the right sibling hash.
+    Left(H::Hash),
+
+    /// Right branch taken, value is the left sibling hash.
+    Right(H::Hash),
+}
+
+/// Merkle proof path, bottom to top.
+#[derive(Clone, PartialEq, Eq, Serialize)]
+pub struct Proof<H: Hasher>(pub Vec<Branch<H>>);
+
+/// For a given node index, return the parent node index
+/// Returns None if there is no parent (root node)
+const fn parent(index: usize) -> Option<usize> {
+    if index == 0 {
+        None
+    } else {
+        Some(((index + 1) >> 1) - 1)
+    }
+}
+
+/// For a given node index, return index of the first (left) child.
+const fn first_child(index: usize) -> usize {
+    (index << 1) + 1
+}
+
+/// For a given node index, return its depth in the tree (the root has depth 0).
+const fn depth(index: usize) -> usize {
+    // `n.next_power_of_two()` will return `n` iff `n` is a power of two.
+    // The extra offset corrects this.
+    (index + 2).next_power_of_two().trailing_zeros() as usize - 1
+}
+
+impl<H: Hasher> MerkleTree<H> {
+    /// Creates a new `MerkleTree`
+    /// * `depth` - The depth of the tree, including the root. This is 1 greater
+    ///   than the `treeLevels` argument to the Semaphore contract.
+    pub fn new(depth: usize, initial_leaf: H::Hash) -> Self {
+        // Compute empty node values, leaf to root
+        let empty = successors(Some(initial_leaf), |prev| Some(H::hash_node(prev, prev)))
+            .take(depth)
+            .collect::<Vec<_>>();
+
+        // Compute node values
+        let nodes = empty
+            .iter()
+            .rev()
+            .enumerate()
+            .flat_map(|(depth, hash)| repeat(hash).take(1 << depth))
+            .cloned()
+            .collect::<Vec<_>>();
+        debug_assert!(nodes.len() == (1 << depth) - 1);
+
+        Self {
+            depth,
+            empty,
+            nodes,
+        }
+    }
+
+    /// Number of leaves the tree can hold (`2^(depth - 1)`, or 0 for depth 0).
+    pub fn num_leaves(&self) -> usize {
+        self.depth
+            .checked_sub(1)
+            .map(|n| 1 << n)
+            .unwrap_or_default()
+    }
+
+    /// Returns the current root hash. Panics if the tree has depth 0.
+    pub fn root(&self) -> H::Hash {
+        self.nodes[0].clone()
+    }
+
+    /// Sets the leaf at index `leaf` and recomputes the affected ancestors.
+    pub fn set(&mut self, leaf: usize, hash: H::Hash) {
+        self.set_range(leaf, once(hash));
+    }
+
+    /// Sets consecutive leaves starting at `start`; hashes past the last leaf are ignored.
+    pub fn set_range<I: IntoIterator<Item = H::Hash>>(&mut self, start: usize, hashes: I) {
+        let index = self.num_leaves() + start - 1;
+        let mut count = 0;
+        // TODO: Error/panic when hashes is longer than available leaves
+        for (leaf, hash) in self.nodes[index..].iter_mut().zip(hashes) {
+            *leaf = hash;
+            count += 1;
+        }
+        if count != 0 {
+            self.update_nodes(index, index + (count - 1));
+        }
+    }
+
+    /// Recomputes the ancestors of the node range `start..=end` up to the root.
+    fn update_nodes(&mut self, start: usize, end: usize) {
+        debug_assert_eq!(depth(start), depth(end));
+        if let (Some(start), Some(end)) = (parent(start), parent(end)) {
+            for parent in start..=end {
+                let child = first_child(parent);
+                self.nodes[parent] = H::hash_node(&self.nodes[child], &self.nodes[child + 1]);
+            }
+            self.update_nodes(start, end);
+        }
+    }
+
+    /// Builds the proof for leaf `leaf`, or `None` if the index is out of range.
+    pub fn proof(&self, leaf: usize) -> Option<Proof<H>> {
+        if leaf >= self.num_leaves() {
+            return None;
+        }
+        let mut index = self.num_leaves() + leaf - 1;
+        let mut path = Vec::with_capacity(self.depth);
+        while let Some(parent) = parent(index) {
+            // Add proof for node at index to parent
+            path.push(match index & 1 {
+                1 => Branch::Left(self.nodes[index + 1].clone()),
+                0 => Branch::Right(self.nodes[index - 1].clone()),
+                _ => unreachable!(),
+            });
+            index = parent;
+        }
+        Some(Proof(path))
+    }
+
+    /// Checks that `proof` leads from leaf `hash` to this tree's root.
+    #[allow(dead_code)]
+    pub fn verify(&self, hash: H::Hash, proof: &Proof<H>) -> bool {
+        proof.root(hash) == self.root()
+    }
+
+    /// Returns the leaf layer as a slice (empty for a depth-0 tree).
+    #[allow(dead_code)]
+    pub fn leaves(&self) -> &[H::Hash] {
+        &self.nodes[self.num_leaves().saturating_sub(1)..]
+    }
+}
+
+impl<H: Hasher> Proof<H> {
+    /// Compute the leaf index for this proof
+    #[allow(dead_code)]
+    pub fn leaf_index(&self) -> usize {
+        self.0.iter().rev().fold(0, |index, branch| match branch {
+            Branch::Left(_) => index << 1,
+            Branch::Right(_) => (index << 1) + 1,
+        })
+    }
+
+    /// Compute path index (TODO: do we want to keep this here?)
+    #[allow(dead_code)]
+    pub fn path_index(&self) -> Vec<BigInt> {
+        self.0
+            .iter()
+            .map(|branch| match branch {
+                Branch::Left(_) => BigInt::from(0),
+                Branch::Right(_) => BigInt::from(1),
+            })
+            .collect()
+    }
+
+    /// Compute the Merkle root given a leaf hash
+    #[allow(dead_code)]
+    pub fn root(&self, hash: H::Hash) -> H::Hash {
+        self.0.iter().fold(hash, |hash, branch| match branch {
+            Branch::Left(sibling) => H::hash_node(&hash, sibling),
+            Branch::Right(sibling) => H::hash_node(sibling, &hash),
+        })
+    }
+}
+
+impl<H> Debug for Branch<H>
+where
+    H: Hasher,
+    H::Hash: Debug,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Left(arg0) => f.debug_tuple("Left").field(arg0).finish(),
+            Self::Right(arg0) => f.debug_tuple("Right").field(arg0).finish(),
+        }
+    }
+}
+
+impl<H> Debug for Proof<H>
+where
+    H: Hasher,
+    H::Hash: Debug,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_tuple("Proof").field(&self.0).finish()
+    }
+}
+
+#[cfg(test)]
+pub mod test {
+    use super::*;
+    use hex_literal::hex;
+    use tiny_keccak::{Hasher as _, Keccak};
+
+    struct Keccak256;
+
+    impl Hasher for Keccak256 {
+        type Hash = [u8; 32];
+
+        fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash {
+            let mut output = [0; 32];
+            let mut hasher = Keccak::v256();
+            hasher.update(left);
+            hasher.update(right);
+            hasher.finalize(&mut output);
+            output
+        }
+    }
+
+    #[test]
+    fn test_index_calculus() {
+        assert_eq!(parent(0), None);
+        assert_eq!(parent(1), Some(0));
+        assert_eq!(parent(2), Some(0));
+        assert_eq!(parent(3), Some(1));
+        assert_eq!(parent(4), Some(1));
+        assert_eq!(parent(5), Some(2));
+        assert_eq!(parent(6), Some(2));
+        assert_eq!(first_child(0), 1);
+        assert_eq!(first_child(2), 5);
+        assert_eq!(depth(0), 0);
+        assert_eq!(depth(1), 1);
+        assert_eq!(depth(2), 1);
+        assert_eq!(depth(3), 2);
+        assert_eq!(depth(6), 2);
+    }
+
+    #[test]
+    fn test_root() {
+        let mut tree = MerkleTree::<Keccak256>::new(3, [0; 32]);
+        assert_eq!(
+            tree.root(),
+            hex!("b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30")
+        );
+        tree.set(
+            0,
+            hex!("0000000000000000000000000000000000000000000000000000000000000001"),
+        );
+        assert_eq!(
+            tree.root(),
+            hex!("c1ba1812ff680ce84c1d5b4f1087eeb08147a4d510f3496b2849df3a73f5af95")
+        );
+        tree.set(
+            1,
+            hex!("0000000000000000000000000000000000000000000000000000000000000002"),
+        );
+        assert_eq!(
+            tree.root(),
+            hex!("893760ec5b5bee236f29e85aef64f17139c3c1b7ff24ce64eb6315fca0f2485b")
+        );
+        tree.set(
+            2,
+            hex!("0000000000000000000000000000000000000000000000000000000000000003"),
+        );
+        assert_eq!(
+            tree.root(),
+            hex!("222ff5e0b5877792c2bc1670e2ccd0c2c97cd7bb1672a57d598db05092d3d72c")
+        );
+        tree.set(
+            3,
+            hex!("0000000000000000000000000000000000000000000000000000000000000004"),
+        );
+        assert_eq!(
+            tree.root(),
+            hex!("a9bb8c3f1f12e9aa903a50c47f314b57610a3ab32f2d463293f58836def38d36")
+        );
+    }
+
+    #[test]
+    fn test_proof() {
+        let mut tree = MerkleTree::<Keccak256>::new(3, [0; 32]);
+        tree.set(
+            0,
+            hex!("0000000000000000000000000000000000000000000000000000000000000001"),
+        );
+        tree.set(
+            1,
+            hex!("0000000000000000000000000000000000000000000000000000000000000002"),
+        );
+        tree.set(
+            2,
+            hex!("0000000000000000000000000000000000000000000000000000000000000003"),
+        );
+        tree.set(
+            3,
+            hex!("0000000000000000000000000000000000000000000000000000000000000004"),
+        );
+
+        let proof = tree.proof(2).expect("proof should exist");
+        assert_eq!(proof.leaf_index(), 2);
+        assert!(tree.verify(
+            hex!("0000000000000000000000000000000000000000000000000000000000000003"),
+            &proof
+        ));
+        assert!(!tree.verify(
+            hex!("0000000000000000000000000000000000000000000000000000000000000001"),
+            &proof
+        ));
+    }
+
+    #[test]
+    fn test_position() {
+        let mut tree = MerkleTree::<Keccak256>::new(3, [0; 32]);
+        tree.set(
+            0,
+            hex!("0000000000000000000000000000000000000000000000000000000000000001"),
+        );
+        tree.set(
+            1,
+            hex!("0000000000000000000000000000000000000000000000000000000000000002"),
+        );
+        tree.set(
+            2,
+            hex!("0000000000000000000000000000000000000000000000000000000000000003"),
+        );
+        tree.set(
+            3,
+            hex!("0000000000000000000000000000000000000000000000000000000000000004"),
+        );
+    }
+}
diff --git a/rln/src/poseidon_tree.rs b/rln/src/poseidon_tree.rs
new file mode 100644
index 0000000..48aff88
--- /dev/null
+++ b/rln/src/poseidon_tree.rs
@@ -0,0 +1,87 @@
+// Adapted from https://github.com/worldcoin/semaphore-rs/blob/main/src/poseidon_tree.rs
+//
+use crate::{
+    hash::Hash,
+    merkle_tree::{self, Hasher, MerkleTree},
+};
+use ff::{PrimeField, PrimeFieldRepr};
+use once_cell::sync::Lazy;
+use poseidon_rs::{Fr, FrRepr, Poseidon};
+use serde::{Deserialize, Serialize};
+
+/// Shared Poseidon hasher instance, created on first use.
+static POSEIDON: Lazy<Poseidon> = Lazy::new(Poseidon::new);
+
+#[allow(dead_code)]
+pub type PoseidonTree = MerkleTree<PoseidonHash>;
+#[allow(dead_code)]
+pub type Branch = merkle_tree::Branch<PoseidonHash>;
+#[allow(dead_code)]
+pub type Proof = merkle_tree::Proof<PoseidonHash>;
+
+/// Marker type selecting Poseidon as the Merkle tree [`Hasher`].
+#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub struct PoseidonHash;
+
+/// Conversion to a field element; panics if `Fr::from_repr` rejects the bytes.
+#[allow(clippy::fallible_impl_from)] // TODO
+impl From<&Hash> for Fr {
+    fn from(hash: &Hash) -> Self {
+        let mut repr = FrRepr::default();
+        repr.read_be(&hash.as_bytes_be()[..]).unwrap();
+        Self::from_repr(repr).unwrap()
+    }
+}
+
+/// Conversion from a field element, written out as big-endian bytes.
+#[allow(clippy::fallible_impl_from)] // TODO
+impl From<Fr> for Hash {
+    fn from(fr: Fr) -> Self {
+        let mut bytes = [0_u8; 32];
+        fr.into_repr().write_be(&mut bytes[..]).unwrap();
+        Self::from_bytes_be(bytes)
+    }
+}
+
+impl Hasher for PoseidonHash {
+    type Hash = Hash;
+
+    fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash {
+        POSEIDON
+            .hash(vec![left.into(), right.into()])
+            .unwrap() // TODO
+            .into()
+    }
+}
+
+#[cfg(test)]
+pub mod test {
+    use super::*;
+    use hex_literal::hex;
+
+    #[test]
+    fn test_tree_4() {
+        const LEAF: Hash = Hash::from_bytes_be(hex!(
+            "0000000000000000000000000000000000000000000000000000000000000000"
+        ));
+
+        let tree = PoseidonTree::new(3, LEAF);
+        assert_eq!(tree.num_leaves(), 4);
+        assert_eq!(
+            tree.root(),
+            Hash::from_bytes_be(hex!(
+                "1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1"
+            ))
+        );
+        let proof = tree.proof(3).expect("proof should exist");
+        assert_eq!(
+            proof,
+            crate::merkle_tree::Proof(vec![
+                Branch::Right(LEAF),
+                Branch::Right(Hash::from_bytes_be(hex!(
+                    "2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864"
+                ))),
+            ])
+        );
+    }
+}