poseidon tree

This commit is contained in:
psippl 2022-01-29 12:52:32 +01:00
parent 997a4ec44b
commit 4d8b87364c
9 changed files with 2450 additions and 78 deletions

1811
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -15,5 +15,9 @@ ff = { package="ff_ce", version="0.11"}
poseidon-rs = "0.0.8"
color-eyre = "0.5"
sha2 = "0.10.1"
hex = "0.3.1"
once_cell = "1.8"
hex = "0.4.0"
once_cell = "1.8"
serde = "1.0"
ethers = "0.6"
hex-literal = "0.3"
proptest = { version = "1.0", optional = true }

202
src/hash.rs Normal file
View File

@ -0,0 +1,202 @@
use ethers::types::U256;
use serde::{
de::{Error as DeError, Visitor},
ser::Error as _,
Deserialize, Serialize,
};
use std::{
fmt::{Debug, Display, Formatter, Result as FmtResult},
str::{from_utf8, FromStr},
};
/// Container for 256-bit hash values.
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct Hash(pub [u8; 32]);
impl Hash {
#[must_use]
pub const fn from_bytes_be(bytes: [u8; 32]) -> Self {
Self(bytes)
}
#[must_use]
pub const fn as_bytes_be(&self) -> &[u8; 32] {
&self.0
}
}
/// Debug print hashes using `hex!(..)` literals.
impl Debug for Hash {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "Hash(hex!(\"{}\"))", hex::encode(&self.0))
}
}
/// Display print hashes as `0x...`.
impl Display for Hash {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "0x{}", hex::encode(&self.0))
}
}
/// Conversion from Ether U256
impl From<&Hash> for U256 {
fn from(hash: &Hash) -> Self {
Self::from_big_endian(hash.as_bytes_be())
}
}
/// Conversion to Ether U256
impl From<U256> for Hash {
fn from(u256: U256) -> Self {
let mut bytes = [0_u8; 32];
u256.to_big_endian(&mut bytes);
Self::from_bytes_be(bytes)
}
}
/// Parse Hash from hex string.
/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix
/// but they must always be exactly 32 bytes.
impl FromStr for Hash {
type Err = hex::FromHexError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let str = trim_hex_prefix(s);
let mut out = [0_u8; 32];
hex::decode_to_slice(str, &mut out)?;
Ok(Self(out))
}
}
/// Serialize hashes into human readable hex strings or byte arrays.
/// Hex strings are lower case without prefix and always 32 bytes.
impl Serialize for Hash {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if serializer.is_human_readable() {
let mut hex_ascii = [0_u8; 64];
hex::encode_to_slice(self.0, &mut hex_ascii)
.map_err(|e| S::Error::custom(format!("Error hex encoding: {}", e)))?;
from_utf8(&hex_ascii)
.map_err(|e| S::Error::custom(format!("Invalid hex encoding: {}", e)))?
.serialize(serializer)
} else {
self.0.serialize(serializer)
}
}
}
/// Deserialize human readable hex strings or byte arrays into hashes.
/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix
/// but they must always be exactly 32 bytes.
impl<'de> Deserialize<'de> for Hash {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(HashStrVisitor)
} else {
<[u8; 32]>::deserialize(deserializer).map(Hash)
}
}
}
struct HashStrVisitor;
impl<'de> Visitor<'de> for HashStrVisitor {
type Value = Hash;
fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
formatter.write_str("a 32 byte hex string")
}
fn visit_borrowed_str<E>(self, value: &'de str) -> Result<Self::Value, E>
where
E: DeError,
{
Hash::from_str(value).map_err(|e| E::custom(format!("Error in hex: {}", e)))
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: DeError,
{
Hash::from_str(value).map_err(|e| E::custom(format!("Error in hex: {}", e)))
}
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
where
E: DeError,
{
Hash::from_str(&value).map_err(|e| E::custom(format!("Error in hex: {}", e)))
}
}
/// Helper function to optionally remove `0x` prefix from hex strings.
fn trim_hex_prefix(str: &str) -> &str {
if str.len() >= 2 && (&str[..2] == "0x" || &str[..2] == "0X") {
&str[2..]
} else {
str
}
}
#[cfg(test)]
pub mod test {
use super::*;
use hex_literal::hex;
use proptest::proptest;
use serde_json::{from_str, to_string};
#[test]
fn test_serialize() {
let hash = Hash([0; 32]);
assert_eq!(
to_string(&hash).unwrap(),
"\"0000000000000000000000000000000000000000000000000000000000000000\""
);
let hash = Hash(hex!(
"1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe"
));
assert_eq!(
to_string(&hash).unwrap(),
"\"1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe\""
);
}
#[test]
fn test_deserialize() {
assert_eq!(
from_str::<Hash>(
"\"0x1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe\""
)
.unwrap(),
Hash(hex!(
"1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe"
))
);
assert_eq!(
from_str::<Hash>(
"\"0X1C4823575d154474EE3e5ac838d002456a815181437afd14f126da58a9912bbe\""
)
.unwrap(),
Hash(hex!(
"1c4823575d154474ee3e5ac838d002456a815181437afd14f126da58a9912bbe"
))
);
}
#[test]
fn test_roundtrip() {
proptest!(|(bytes: [u8; 32])| {
let hash = Hash(bytes);
let json = to_string(&hash).unwrap();
let parsed = from_str(&json).unwrap();
assert_eq!(hash, parsed);
});
}
}

View File

@ -5,24 +5,10 @@ use once_cell::sync::Lazy;
use poseidon_rs::{Fr, FrRepr, Poseidon};
use sha2::{Digest, Sha256};
use crate::{hash::Hash};
static POSEIDON: Lazy<Poseidon> = Lazy::new(Poseidon::new);
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Identity {
identity_trapdoor: BigInt,
identity_nullifier: BigInt,
}
// todo: improve
fn sha(msg: &[u8]) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(msg);
let result = hasher.finalize();
let res: [u8; 32] = result.into();
res
}
// todo: improve
fn bigint_to_fr(bi: &BigInt) -> Fr {
// dirty: have to force the point into the field manually, otherwise you get an error if bi not in field
let q = BigInt::parse_bytes(
@ -44,6 +30,21 @@ fn fr_to_bigint(fr: Fr) -> BigInt {
BigInt::from_bytes_be(Sign::Plus, &bytes)
}
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Identity {
identity_trapdoor: BigInt,
identity_nullifier: BigInt,
}
// todo: improve
fn sha(msg: &[u8]) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(msg);
let result = hasher.finalize();
let res: [u8; 32] = result.into();
res
}
impl Identity {
pub fn new(seed: &[u8]) -> Self {
let seed_hash = &sha(seed);
@ -80,4 +81,12 @@ impl Identity {
.unwrap();
fr_to_bigint(res)
}
pub fn identity_commitment_leaf(&self) -> Hash {
let res = POSEIDON
.hash(vec![bigint_to_fr(&self.identity_commitment())])
.unwrap();
res.into()
}
}

View File

@ -1,6 +1,11 @@
mod identity;
mod proof;
mod merkle_tree;
mod poseidon_tree;
mod hash;
use std::os::raw::{c_char};
use std::ffi::{CString, CStr};
mod identity;
#[no_mangle]
pub extern fn generate_identity_commitment(seed: *const c_char) -> *mut c_char {

View File

@ -1,9 +1,32 @@
mod proof;
mod identity;
mod merkle_tree;
mod poseidon_tree;
mod hash;
use poseidon_rs::Poseidon;
use hex_literal::hex;
use {identity::*, poseidon_tree::*, hash::*};
fn main() {
// proof::Proof_signal().unwrap();
let id = identity::Identity::new(b"hello");
// generate identity
let id = Identity::new(b"hello");
dbg!(&id);
dbg!(id.identity_commitment());
// generate merkle tree
const LEAF: Hash = Hash::from_bytes_be(hex!(
"0000000000000000000000000000000000000000000000000000000000000000"
));
let mut tree = PoseidonTree::new(3, LEAF);
tree.set(0, id.identity_commitment_leaf());
dbg!(tree.root());
let proof = tree.proof(0).expect("proof should exist");
dbg!(proof);
}

335
src/merkle_tree.rs Normal file
View File

@ -0,0 +1,335 @@
//! Implements basic binary Merkle trees
//!
//! # To do
//!
//! * Disk based storage backend (using mmaped files should be easy)
use std::{
fmt::Debug,
iter::{once, repeat, successors},
};
use serde::Serialize;
/// Hash types, values and algorithms for a Merkle tree
pub trait Hasher {
/// Type of the leaf and node hashes
type Hash: Clone + Eq + Serialize;
/// Compute the hash of an intermediate node
fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash;
}
/// Merkle tree with all leaf and intermediate hashes stored
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct MerkleTree<H: Hasher> {
/// Depth of the tree, # of layers including leaf layer
depth: usize,
/// Hash value of empty subtrees of given depth, starting at leaf level
empty: Vec<H::Hash>,
/// Hash values of tree nodes and leaves, breadth first order
nodes: Vec<H::Hash>,
}
/// Element of a Merkle proof
#[derive(Clone, Copy, PartialEq, Eq, Serialize)]
pub enum Branch<H: Hasher> {
/// Left branch taken, value is the right sibling hash.
Left(H::Hash),
/// Right branch taken, value is the left sibling hash.
Right(H::Hash),
}
/// Merkle proof path, bottom to top.
#[derive(Clone, PartialEq, Eq, Serialize)]
pub struct Proof<H: Hasher>(pub Vec<Branch<H>>);
/// For a given node index, return the parent node index
/// Returns None if there is no parent (root node)
const fn parent(index: usize) -> Option<usize> {
if index == 0 {
None
} else {
Some(((index + 1) >> 1) - 1)
}
}
/// For a given node index, return index of the first (left) child.
const fn first_child(index: usize) -> usize {
(index << 1) + 1
}
const fn depth(index: usize) -> usize {
// `n.next_power_of_two()` will return `n` iff `n` is a power of two.
// The extra offset corrects this.
(index + 2).next_power_of_two().trailing_zeros() as usize - 1
}
impl<H: Hasher> MerkleTree<H> {
/// Creates a new `MerkleTree`
/// * `depth` - The depth of the tree, including the root. This is 1 greater
/// than the `treeLevels` argument to the Semaphore contract.
pub fn new(depth: usize, initial_leaf: H::Hash) -> Self {
// Compute empty node values, leaf to root
let empty = successors(Some(initial_leaf), |prev| Some(H::hash_node(prev, prev)))
.take(depth)
.collect::<Vec<_>>();
// Compute node values
let nodes = empty
.iter()
.rev()
.enumerate()
.flat_map(|(depth, hash)| repeat(hash).take(1 << depth))
.cloned()
.collect::<Vec<_>>();
debug_assert!(nodes.len() == (1 << depth) - 1);
Self {
depth,
empty,
nodes,
}
}
pub fn num_leaves(&self) -> usize {
self.depth
.checked_sub(1)
.map(|n| 1 << n)
.unwrap_or_default()
}
pub fn root(&self) -> H::Hash {
self.nodes[0].clone()
}
pub fn set(&mut self, leaf: usize, hash: H::Hash) {
self.set_range(leaf, once(hash));
}
pub fn set_range<I: IntoIterator<Item = H::Hash>>(&mut self, start: usize, hashes: I) {
let index = self.num_leaves() + start - 1;
let mut count = 0;
// TODO: Error/panic when hashes is longer than available leafs
for (leaf, hash) in self.nodes[index..].iter_mut().zip(hashes) {
*leaf = hash;
count += 1;
}
if count != 0 {
self.update_nodes(index, index + (count - 1));
}
}
fn update_nodes(&mut self, start: usize, end: usize) {
debug_assert_eq!(depth(start), depth(end));
if let (Some(start), Some(end)) = (parent(start), parent(end)) {
for parent in start..=end {
let child = first_child(parent);
self.nodes[parent] = H::hash_node(&self.nodes[child], &self.nodes[child + 1]);
}
self.update_nodes(start, end);
}
}
pub fn proof(&self, leaf: usize) -> Option<Proof<H>> {
if leaf >= self.num_leaves() {
return None;
}
let mut index = self.num_leaves() + leaf - 1;
let mut path = Vec::with_capacity(self.depth);
while let Some(parent) = parent(index) {
// Add proof for node at index to parent
path.push(match index & 1 {
1 => Branch::Left(self.nodes[index + 1].clone()),
0 => Branch::Right(self.nodes[index - 1].clone()),
_ => unreachable!(),
});
index = parent;
}
Some(Proof(path))
}
#[allow(dead_code)]
pub fn verify(&self, hash: H::Hash, proof: &Proof<H>) -> bool {
proof.root(hash) == self.root()
}
pub fn leaves(&self) -> &[H::Hash] {
&self.nodes[(self.num_leaves() - 1)..]
}
}
impl<H: Hasher> Proof<H> {
/// Compute the leaf index for this proof
#[allow(dead_code)]
pub fn leaf_index(&self) -> usize {
self.0.iter().rev().fold(0, |index, branch| match branch {
Branch::Left(_) => index << 1,
Branch::Right(_) => (index << 1) + 1,
})
}
/// Compute the Merkle root given a leaf hash
#[allow(dead_code)]
pub fn root(&self, hash: H::Hash) -> H::Hash {
self.0.iter().fold(hash, |hash, branch| match branch {
Branch::Left(sibling) => H::hash_node(&hash, sibling),
Branch::Right(sibling) => H::hash_node(sibling, &hash),
})
}
}
impl<H> Debug for Branch<H>
where
H: Hasher,
H::Hash: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Left(arg0) => f.debug_tuple("Left").field(arg0).finish(),
Self::Right(arg0) => f.debug_tuple("Right").field(arg0).finish(),
}
}
}
impl<H> Debug for Proof<H>
where
H: Hasher,
H::Hash: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Proof").field(&self.0).finish()
}
}
#[cfg(test)]
pub mod test {
use super::*;
use ethers::utils::keccak256;
use hex_literal::hex;
struct Keccak;
impl Hasher for Keccak {
type Hash = [u8; 32];
fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash {
keccak256([*left, *right].concat())
}
}
#[test]
fn test_index_calculus() {
assert_eq!(parent(0), None);
assert_eq!(parent(1), Some(0));
assert_eq!(parent(2), Some(0));
assert_eq!(parent(3), Some(1));
assert_eq!(parent(4), Some(1));
assert_eq!(parent(5), Some(2));
assert_eq!(parent(6), Some(2));
assert_eq!(first_child(0), 1);
assert_eq!(first_child(2), 5);
assert_eq!(depth(0), 0);
assert_eq!(depth(1), 1);
assert_eq!(depth(2), 1);
assert_eq!(depth(3), 2);
assert_eq!(depth(6), 2);
}
#[test]
fn test_root() {
let mut tree = MerkleTree::<Keccak>::new(3, [0; 32]);
assert_eq!(
tree.root(),
hex!("b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30")
);
tree.set(
0,
hex!("0000000000000000000000000000000000000000000000000000000000000001"),
);
assert_eq!(
tree.root(),
hex!("c1ba1812ff680ce84c1d5b4f1087eeb08147a4d510f3496b2849df3a73f5af95")
);
tree.set(
1,
hex!("0000000000000000000000000000000000000000000000000000000000000002"),
);
assert_eq!(
tree.root(),
hex!("893760ec5b5bee236f29e85aef64f17139c3c1b7ff24ce64eb6315fca0f2485b")
);
tree.set(
2,
hex!("0000000000000000000000000000000000000000000000000000000000000003"),
);
assert_eq!(
tree.root(),
hex!("222ff5e0b5877792c2bc1670e2ccd0c2c97cd7bb1672a57d598db05092d3d72c")
);
tree.set(
3,
hex!("0000000000000000000000000000000000000000000000000000000000000004"),
);
assert_eq!(
tree.root(),
hex!("a9bb8c3f1f12e9aa903a50c47f314b57610a3ab32f2d463293f58836def38d36")
);
}
#[test]
fn test_proof() {
let mut tree = MerkleTree::<Keccak>::new(3, [0; 32]);
tree.set(
0,
hex!("0000000000000000000000000000000000000000000000000000000000000001"),
);
tree.set(
1,
hex!("0000000000000000000000000000000000000000000000000000000000000002"),
);
tree.set(
2,
hex!("0000000000000000000000000000000000000000000000000000000000000003"),
);
tree.set(
3,
hex!("0000000000000000000000000000000000000000000000000000000000000004"),
);
let proof = tree.proof(2).expect("proof should exist");
assert_eq!(proof.leaf_index(), 2);
assert!(tree.verify(
hex!("0000000000000000000000000000000000000000000000000000000000000003"),
&proof
));
assert!(!tree.verify(
hex!("0000000000000000000000000000000000000000000000000000000000000001"),
&proof
));
}
#[test]
fn test_position() {
let mut tree = MerkleTree::<Keccak>::new(3, [0; 32]);
tree.set(
0,
hex!("0000000000000000000000000000000000000000000000000000000000000001"),
);
tree.set(
1,
hex!("0000000000000000000000000000000000000000000000000000000000000002"),
);
tree.set(
2,
hex!("0000000000000000000000000000000000000000000000000000000000000003"),
);
tree.set(
3,
hex!("0000000000000000000000000000000000000000000000000000000000000004"),
);
}
}

81
src/poseidon_tree.rs Normal file
View File

@ -0,0 +1,81 @@
use crate::{
hash::Hash,
merkle_tree::{self, Hasher, MerkleTree},
};
use ff::{PrimeField, PrimeFieldRepr};
use once_cell::sync::Lazy;
use poseidon_rs::{Fr, FrRepr, Poseidon};
use serde::Serialize;
static POSEIDON: Lazy<Poseidon> = Lazy::new(Poseidon::new);
#[allow(dead_code)]
pub type PoseidonTree = MerkleTree<PoseidonHash>;
#[allow(dead_code)]
pub type Branch = merkle_tree::Branch<PoseidonHash>;
#[allow(dead_code)]
pub type Proof = merkle_tree::Proof<PoseidonHash>;
#[derive(Clone, Copy, PartialEq, Eq, Serialize)]
pub struct PoseidonHash;
#[allow(clippy::fallible_impl_from)] // TODO
impl From<&Hash> for Fr {
fn from(hash: &Hash) -> Self {
let mut repr = FrRepr::default();
repr.read_be(&hash.as_bytes_be()[..]).unwrap();
Self::from_repr(repr).unwrap()
}
}
#[allow(clippy::fallible_impl_from)] // TODO
impl From<Fr> for Hash {
fn from(fr: Fr) -> Self {
let mut bytes = [0_u8; 32];
fr.into_repr().write_be(&mut bytes[..]).unwrap();
Self::from_bytes_be(bytes)
}
}
impl Hasher for PoseidonHash {
type Hash = Hash;
fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash {
POSEIDON
.hash(vec![left.into(), right.into()])
.unwrap() // TODO
.into()
}
}
#[cfg(test)]
pub mod test {
use super::*;
use hex_literal::hex;
#[test]
fn test_tree_4() {
const LEAF: Hash = Hash::from_bytes_be(hex!(
"0000000000000000000000000000000000000000000000000000000000000000"
));
let tree = PoseidonTree::new(3, LEAF);
assert_eq!(tree.num_leaves(), 4);
assert_eq!(
tree.root(),
Hash::from_bytes_be(hex!(
"1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1"
))
);
let proof = tree.proof(3).expect("proof should exist");
assert_eq!(
proof,
crate::merkle_tree::Proof(vec![
Branch::Right(LEAF),
Branch::Right(Hash::from_bytes_be(hex!(
"2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864"
))),
])
);
}
}

View File

@ -3,19 +3,29 @@ use ark_std::rand::thread_rng;
use ark_bn254::Bn254;
use color_eyre::Result;
use crate::identity::*;
use ark_groth16::{
create_random_proof as prove, generate_random_parameters, prepare_verifying_key, verify_proof,
};
// WIP: uses dummy proofs for now
fn proof_signal() -> Result<()> {
fn proof_signal(identity: Identity) -> Result<()> {
let cfg = CircomConfig::<Bn254>::new(
"./snarkfiles/circom2_multiplier2.wasm",
"./snarkfiles/circom2_multiplier2.r1cs",
)?;
// identity_nullifier: identityNullifier,
// identity_trapdoor: identityTrapdoor,
// identity_path_index: merkleProof.pathIndices,
// path_elements: merkleProof.siblings,
// external_nullifier: externalNullifier,
// signal_hash: shouldHash ? genSignalHash(signal) : signal
let mut builder = CircomBuilder::new(cfg);
builder.push_input("a", 3);
builder.push_input("b", 11);
// builder.push_input("a", 3);
// builder.push_input("b", 11);
// create an empty instance for setting it up
let circom = builder.setup();