diff --git a/Cargo.toml b/Cargo.toml index 29f6b86..d982abf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ primitive-types = "0.11.1" proptest = { version = "1.0", optional = true } rand = "0.8.4" rayon = "1.5.1" -ruint = { version = "1.1.0", path = "../../WV/uint" } +ruint = { version = "1.1.0", path = "../../WV/uint", features = [ "serde", "poseidon-rs", "num-bigint", "ark-ff" ] } serde = "1.0" sha2 = "0.10.1" thiserror = "1.0.0" diff --git a/src/field.rs b/src/field.rs index 5378f0e..07a7ade 100644 --- a/src/field.rs +++ b/src/field.rs @@ -9,126 +9,18 @@ use core::{ use ff::{PrimeField as _, PrimeFieldRepr as _}; use num_bigint::{BigInt, Sign}; use poseidon_rs::Fr as PosField; +use ruint::{aliases::U256, uint}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// An element of the BN254 scalar field Fr. /// /// Represented as a big-endian byte vector without Montgomery reduction. -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] // TODO: Make sure value is always reduced. -pub struct Field([u8; 32]); +pub type Field = U256; -impl Field { - /// Construct a field element from a big-endian byte vector. - #[must_use] - pub fn from_be_bytes_mod_order(bytes: &[u8]) -> Self { - ArkField::from_be_bytes_mod_order(bytes).into() - } - - /// Convert to big-endian 32-byte array. - #[must_use] - pub const fn to_be_bytes(&self) -> [u8; 32] { - self.0 - } -} - -impl From for Field { - fn from(value: u64) -> Self { - ArkField::from(value).into() - } -} - -impl From for Field { - fn from(value: ArkField) -> Self { - let mut bytes = [0_u8; 32]; - let byte_vec = value.into_repr().to_bytes_be(); - bytes.copy_from_slice(&byte_vec[..]); - Self(bytes) - } -} - -impl From for ArkField { - fn from(value: Field) -> Self { - Self::from_be_bytes_mod_order(&value.0[..]) - } -} - -impl From for Field { - fn from(value: PosField) -> Self { - let mut bytes = [0u8; 32]; - value - .into_repr() - .write_be(&mut bytes[..]) - .expect("write to correctly sized slice always succeeds"); - Self(bytes) - } -} - -impl From for PosField { - fn from(value: Field) -> Self { - let mut repr = ::Repr::default(); - repr.read_be(&value.0[..]) - .expect("read from correctly sized slice always succeeds"); - Self::from_repr(repr).expect("value is always in range") - } -} - -impl From for BigInt { - fn from(value: Field) -> Self { - Self::from_bytes_be(Sign::Plus, &value.0[..]) - } -} - -impl Debug for Field { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let hex = bytes_to_hex::<32, 66>(&self.0); - let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8"); - write!(f, "Field({})", hex_str) - } -} - -impl Display for Field { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let hex = bytes_to_hex::<32, 66>(&self.0); - let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8"); - write!(f, "{}", hex_str) - } -} - -/// Serialize a field element. -/// -/// For human readable formats a `0x` prefixed lower case hex string is used. -/// For binary formats a byte array is used. -impl Serialize for Field { - fn serialize(&self, serializer: S) -> Result { - serialize_bytes::<32, 66, S>(serializer, &self.0) - } -} - -/// Parse Hash from hex string. -/// -/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix -/// but they must always be exactly 32 bytes. -/// -/// Too large values are reduced modulo the field prime. -impl FromStr for Field { - type Err = hex::FromHexError; - - fn from_str(s: &str) -> Result { - let bytes = bytes_from_hex::<32>(s)?; - Ok(Self::from_be_bytes_mod_order(&bytes[..])) - } -} - -/// Deserialize human readable hex strings or byte arrays into hashes. -/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix -/// but they must always be exactly 32 bytes. -impl<'de> Deserialize<'de> for Field { - fn deserialize>(deserializer: D) -> Result { - let bytes = deserialize_bytes::<32, _>(deserializer)?; - Ok(Self::from_be_bytes_mod_order(&bytes)) - } -} +// See +pub const MODULUS: Field = + uint!(21888242871839275222246405745257275088548364400416034343698204186575808495617_U256); /// Hash arbitrary data to a field element. /// @@ -136,11 +28,9 @@ impl<'de> Deserialize<'de> for Field { #[must_use] #[allow(clippy::module_name_repetitions)] pub fn hash_to_field(data: &[u8]) -> Field { - let hash = keccak256(data); + let n = U256::try_from_be_slice(&keccak256(data)).unwrap(); // Shift right one byte to make it fit in the field - let mut bytes = [0_u8; 32]; - bytes[1..].copy_from_slice(&hash[..31]); - Field(bytes) + n >> 8 } #[cfg(test)] diff --git a/src/identity.rs b/src/identity.rs index 89eedd9..df2dd78 100644 --- a/src/identity.rs +++ b/src/identity.rs @@ -1,4 +1,4 @@ -use crate::{poseidon_hash, Field}; +use crate::{field::MODULUS, poseidon_hash, Field}; use sha2::{Digest, Sha256}; #[derive(Clone, PartialEq, Eq, Debug)] @@ -14,7 +14,7 @@ fn derive_field(seed_hex: &[u8; 64], suffix: &[u8]) -> Field { let mut hasher = Sha256::new(); hasher.update(seed_hex); hasher.update(suffix); - Field::from_be_bytes_mod_order(hasher.finalize().as_ref()) + Field::try_from_be_slice(hasher.finalize().as_ref()).unwrap() % MODULUS } fn seed_hex(seed: &[u8]) -> [u8; 64] { diff --git a/src/lib.rs b/src/lib.rs index 7e30d6f..c103974 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,7 +67,6 @@ mod test { let merkle_proof = tree.proof(0).expect("proof should exist"); let root = tree.root(); - dbg!(root); let signal_hash = hash_to_field(signal); let external_nullifier_hash = hash_to_field(external_nullifier); diff --git a/src/mimc_hash.rs b/src/mimc_hash.rs index 6402a9b..9a88ced 100644 --- a/src/mimc_hash.rs +++ b/src/mimc_hash.rs @@ -54,7 +54,6 @@ pub fn hash(values: &[U256]) -> U256 { #[cfg(test)] pub mod test { use super::*; - use hex_literal::hex; #[test] fn test_round_constants() { @@ -108,7 +107,7 @@ pub mod test { // See assert_eq!( hash(&[U256::from(1_u64), U256::from(2_u64)]), - uint!(2bcea035a1251603f1ceaf73cd4ae89427c47075bb8e3a944039ff1e3d6d2a6f_U256) + uint!(0x2bcea035a1251603f1ceaf73cd4ae89427c47075bb8e3a944039ff1e3d6d2a6f_U256) ); assert_eq!( hash(&[ @@ -117,7 +116,7 @@ pub mod test { U256::from(3_u64), U256::from(4_u64) ]), - uint!(03e86bdc4eac70bd601473c53d8233b145fe8fd8bf6ef25f0b217a1da305665c_U256) + uint!(0x03e86bdc4eac70bd601473c53d8233b145fe8fd8bf6ef25f0b217a1da305665c_U256) ); } } diff --git a/src/poseidon_hash.rs b/src/poseidon_hash.rs index 929f953..6648ca3 100644 --- a/src/poseidon_hash.rs +++ b/src/poseidon_hash.rs @@ -6,10 +6,20 @@ static POSEIDON: Lazy = Lazy::new(Poseidon::new); #[must_use] pub fn poseidon_hash(input: &[Field]) -> Field { - let input = input.iter().copied().map(Into::into).collect::>(); + let input = input.iter().map(Into::into).collect::>(); POSEIDON .hash(input) .map(Into::into) .expect("hash with fixed input size can't fail") } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty() { + assert_eq!(poseidon_hash(&[]), Field::ZERO); + } +} diff --git a/src/protocol.rs b/src/protocol.rs index beb84d0..9a88ad2 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -6,7 +6,7 @@ use crate::{ poseidon_tree::PoseidonHash, Field, }; -use ark_bn254::{Bn254, Parameters}; +use ark_bn254::{Bn254, Fr, Parameters}; use ark_circom::CircomReduction; use ark_ec::bn::Bn; use ark_groth16::{ @@ -87,6 +87,8 @@ pub enum ProofError { WitnessError(color_eyre::Report), #[error("Error producing proof: {0}")] SynthesisError(#[from] SynthesisError), + #[error("Error converting public input: {0}")] + ToFieldError(#[from] ruint::ToFieldError), } /// Generates a semaphore proof @@ -150,7 +152,7 @@ fn generate_proof_rs( let inputs = inputs.into_iter().map(|(name, values)| { ( name.to_string(), - values.iter().copied().map(Into::into).collect::>(), + values.iter().map(Into::into).collect::>(), ) }); @@ -197,12 +199,11 @@ pub fn verify_proof( let zkey = zkey(); let pvk = prepare_verifying_key(&zkey.0.vk); - let public_inputs = [ - root.into(), - nullifier_hash.into(), - signal_hash.into(), - external_nullifier_hash.into(), - ]; + let public_inputs = [root, nullifier_hash, signal_hash, external_nullifier_hash] + .iter() + .map(|n| Fr::try_from(n)) + .collect::, _>>()?; + let ark_proof = (*proof).into(); let result = ark_groth16::verify_proof(&pvk, &ark_proof, &public_inputs[..])?; Ok(result)