diff --git a/Cargo.toml b/Cargo.toml index 045a3af..29f6b86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ license-file = "mit-license.md" [features] default = [] bench = [ "criterion", "proptest" ] -mimc = [ "zkp-u256" ] +mimc = [] dylib = [ "wasmer/dylib", "wasmer-engine-dylib", "wasmer-compiler-cranelift" ] [[bench]] @@ -45,12 +45,12 @@ primitive-types = "0.11.1" proptest = { version = "1.0", optional = true } rand = "0.8.4" rayon = "1.5.1" +ruint = { version = "1.1.0", path = "../../WV/uint" } serde = "1.0" sha2 = "0.10.1" thiserror = "1.0.0" tiny-keccak = { version = "2.0.2" } wasmer = { version = "2.0" } -zkp-u256 = { version = "0.2", optional = true } # TODO: Remove # Use the same `ethers-core` version as ark-circom # TODO: Remove diff --git a/src/mimc_hash.rs b/src/mimc_hash.rs index 0eea675..6402a9b 100644 --- a/src/mimc_hash.rs +++ b/src/mimc_hash.rs @@ -10,16 +10,12 @@ use crate::util::keccak256; use once_cell::sync::Lazy; -use zkp_u256::U256; +use ruint::{aliases::U256, uint}; const NUM_ROUNDS: usize = 220; -static MODULUS: Lazy = Lazy::new(|| { - U256::from_decimal_str( - "21888242871839275222246405745257275088548364400416034343698204186575808495617", - ) - .unwrap() -}); +static MODULUS: U256 = + uint!(21888242871839275222246405745257275088548364400416034343698204186575808495617_U256); static ROUND_CONSTANTS: Lazy<[U256; NUM_ROUNDS]> = Lazy::new(|| { const SEED: &str = "mimcsponge"; @@ -27,24 +23,18 @@ static ROUND_CONSTANTS: Lazy<[U256; NUM_ROUNDS]> = Lazy::new(|| { let mut bytes = keccak256(SEED.as_bytes()); for constant in result[1..NUM_ROUNDS - 1].iter_mut() { bytes = keccak256(&bytes); - *constant = U256::from_bytes_be(&bytes); - *constant %= &*MODULUS; + *constant = U256::try_from_be_slice(&bytes).unwrap() % MODULUS; } result }); /// See fn mix(left: &mut U256, right: &mut U256) { - debug_assert!(*left < *MODULUS); - debug_assert!(*right < *MODULUS); - for round_constant in &*ROUND_CONSTANTS { + for round_constant in *ROUND_CONSTANTS { // Modulus is less than 2**252, so addition doesn't overflow - let t = (&*left + round_constant) % &*MODULUS; - let t2 = t.mulmod(&t, &*MODULUS); - let t4 = t2.mulmod(&t2, &*MODULUS); - let t5 = t.mulmod(&t4, &*MODULUS); - *right += t5; - *right %= &*MODULUS; + let t = left.add_mod(round_constant, MODULUS); + let t5 = t.pow_mod(U256::from(5), MODULUS); + *right = right.add_mod(t5, MODULUS); std::mem::swap(left, right); } std::mem::swap(left, right); @@ -54,10 +44,8 @@ fn mix(left: &mut U256, right: &mut U256) { pub fn hash(values: &[U256]) -> U256 { let mut left = U256::ZERO; let mut right = U256::ZERO; - for value in values { - let value = value % &*MODULUS; - left += value; - left %= &*MODULUS; + for &value in values { + left = left.add_mod(value, MODULUS); mix(&mut left, &mut right); } left @@ -74,62 +62,44 @@ pub mod test { assert_eq!(ROUND_CONSTANTS[0], U256::ZERO); assert_eq!( ROUND_CONSTANTS[1], - U256::from_decimal_str( - "7120861356467848435263064379192047478074060781135320967663101236819528304084" - ) - .unwrap() + uint!(7120861356467848435263064379192047478074060781135320967663101236819528304084_U256) ); assert_eq!( ROUND_CONSTANTS[2], - U256::from_decimal_str( - "5024705281721889198577876690145313457398658950011302225525409148828000436681" + uint!(5024705281721889198577876690145313457398658950011302225525409148828000436681_U256 ) - .unwrap() ); assert_eq!( ROUND_CONSTANTS[218], - U256::from_decimal_str( - "2119542016932434047340813757208803962484943912710204325088879681995922344971" + uint!(2119542016932434047340813757208803962484943912710204325088879681995922344971_U256 ) - .unwrap() ); assert_eq!(ROUND_CONSTANTS[219], U256::ZERO); } #[test] fn test_mix() { - let mut left = U256::ONE; + let mut left = U256::from(1); let mut right = U256::ZERO; mix(&mut left, &mut right); assert_eq!( left, - U256::from_decimal_str( - "8792246410719720074073794355580855662772292438409936688983564419486782556587" + uint!(8792246410719720074073794355580855662772292438409936688983564419486782556587_U256 ) - .unwrap() ); assert_eq!( right, - U256::from_decimal_str( - "7326554092124867281481480523863654579712861994895051796475958890524736238844" - ) - .unwrap() + uint!(7326554092124867281481480523863654579712861994895051796475958890524736238844_U256) ); left += U256::from(2); mix(&mut left, &mut right); assert_eq!( left, - U256::from_decimal_str( - "19814528709687996974327303300007262407299502847885145507292406548098437687919" - ) - .unwrap() + uint!(19814528709687996974327303300007262407299502847885145507292406548098437687919_U256) ); assert_eq!( right, - U256::from_decimal_str( - "3888906192024793285683241274210746486868893421288515595586335488978789653213" - ) - .unwrap() + uint!(3888906192024793285683241274210746486868893421288515595586335488978789653213_U256) ); } @@ -138,9 +108,7 @@ pub mod test { // See assert_eq!( hash(&[U256::from(1_u64), U256::from(2_u64)]), - U256::from_bytes_be(&hex!( - "2bcea035a1251603f1ceaf73cd4ae89427c47075bb8e3a944039ff1e3d6d2a6f" - )) + uint!(2bcea035a1251603f1ceaf73cd4ae89427c47075bb8e3a944039ff1e3d6d2a6f_U256) ); assert_eq!( hash(&[ @@ -149,9 +117,7 @@ pub mod test { U256::from(3_u64), U256::from(4_u64) ]), - U256::from_bytes_be(&hex!( - "03e86bdc4eac70bd601473c53d8233b145fe8fd8bf6ef25f0b217a1da305665c" - )) + uint!(03e86bdc4eac70bd601473c53d8233b145fe8fd8bf6ef25f0b217a1da305665c_U256) ); } } @@ -167,7 +133,7 @@ pub mod bench { } fn bench_mix(criterion: &mut Criterion) { - let mut left = U256::ONE; + let mut left = U256::from(1); let mut right = U256::ZERO; criterion.bench_function("mimc_mix", move |bencher| { bencher.iter(|| mix(&mut left, &mut right)); diff --git a/src/mimc_tree.rs b/src/mimc_tree.rs index f8e0248..1862796 100644 --- a/src/mimc_tree.rs +++ b/src/mimc_tree.rs @@ -3,8 +3,8 @@ use crate::{ merkle_tree::{self, Hasher, MerkleTree}, mimc_hash::hash, }; +use ruint::aliases::U256; use serde::Serialize; -use zkp_u256::U256; pub type MimcTree = MerkleTree; #[allow(dead_code)] @@ -19,9 +19,9 @@ impl Hasher for MimcHash { type Hash = Hash; fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash { - let left = U256::from_bytes_be(left.as_bytes_be()); - let right = U256::from_bytes_be(right.as_bytes_be()); - Hash::from_bytes_be(hash(&[left, right]).to_bytes_be()) + let left = U256::try_from_be_slice(left.as_bytes_be()).unwrap(); + let right = U256::try_from_be_slice(right.as_bytes_be()).unwrap(); + Hash::from_bytes_be(hash(&[left, right]).to_be_bytes()) } }