mirror of
https://github.com/vacp2p/semaphore-rs.git
synced 2025-02-23 09:08:28 +00:00
Merge pull request #10 from worldcoin/remco/concurrent-proof
Work around concurrent witness calculator bug
This commit is contained in:
commit
d861a73645
@ -6,6 +6,7 @@ use core::include_bytes;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::io::{Cursor, Write};
|
||||
use tempfile::NamedTempFile;
|
||||
use std::sync::Mutex;
|
||||
|
||||
const ZKEY_BYTES: &[u8] = include_bytes!("../semaphore/build/snark/semaphore_final.zkey");
|
||||
const WASM: &[u8] = include_bytes!("../semaphore/build/snark/semaphore.wasm");
|
||||
@ -15,7 +16,7 @@ pub static ZKEY: Lazy<(ProvingKey<Bn254>, ConstraintMatrices<Fr>)> = Lazy::new(|
|
||||
read_zkey(&mut reader).expect("zkey should be valid")
|
||||
});
|
||||
|
||||
pub static WITNESS_CALCULATOR: Lazy<WitnessCalculator> = Lazy::new(|| {
|
||||
pub static WITNESS_CALCULATOR: Lazy<Mutex<WitnessCalculator>> = Lazy::new(|| {
|
||||
// HACK: ark-circom requires a file, so we make one!
|
||||
let mut tmpfile = NamedTempFile::new().expect("Failed to create temp file");
|
||||
let written = tmpfile.write(WASM).expect("Failed to write to temp file");
|
||||
@ -23,5 +24,5 @@ pub static WITNESS_CALCULATOR: Lazy<WitnessCalculator> = Lazy::new(|| {
|
||||
let path = tmpfile.into_temp_path();
|
||||
let result = WitnessCalculator::new(&path).expect("Failed to create witness calculator");
|
||||
path.close().expect("Could not remove tempfile");
|
||||
result
|
||||
Mutex::new(result)
|
||||
});
|
||||
|
26
src/field.rs
26
src/field.rs
@ -1,7 +1,11 @@
|
||||
use crate::util::{bytes_from_hex, deserialize_bytes, keccak256, serialize_bytes};
|
||||
use crate::util::{bytes_from_hex, bytes_to_hex, deserialize_bytes, keccak256, serialize_bytes};
|
||||
use ark_bn254::Fr as ArkField;
|
||||
use ark_ff::{BigInteger as _, PrimeField as _};
|
||||
use core::{str, str::FromStr};
|
||||
use core::{
|
||||
fmt::{Debug, Display},
|
||||
str,
|
||||
str::FromStr,
|
||||
};
|
||||
use ff::{PrimeField as _, PrimeFieldRepr as _};
|
||||
use num_bigint::{BigInt, Sign};
|
||||
use poseidon_rs::Fr as PosField;
|
||||
@ -10,7 +14,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
/// An element of the BN254 scalar field Fr.
|
||||
///
|
||||
/// Represented as a big-endian byte vector without Montgomery reduction.
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
// TODO: Make sure value is always reduced.
|
||||
pub struct Field([u8; 32]);
|
||||
|
||||
@ -69,6 +73,22 @@ impl From<Field> for BigInt {
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Field {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let hex = bytes_to_hex::<32, 66>(&self.0);
|
||||
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
|
||||
write!(f, "Field({})", hex_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Field {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let hex = bytes_to_hex::<32, 66>(&self.0);
|
||||
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
|
||||
write!(f, "{}", hex_str)
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize a field element.
|
||||
///
|
||||
/// For human readable formats a `0x` prefixed lower case hex string is used.
|
||||
|
41
src/hash.rs
41
src/hash.rs
@ -1,11 +1,12 @@
|
||||
use crate::util::{bytes_from_hex, deserialize_bytes, serialize_bytes};
|
||||
use crate::util::{bytes_from_hex, bytes_to_hex, deserialize_bytes, serialize_bytes};
|
||||
use core::{
|
||||
fmt::{Debug, Display},
|
||||
str,
|
||||
str::FromStr,
|
||||
};
|
||||
use ethers_core::types::U256;
|
||||
use num_bigint::{BigInt, Sign};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::{
|
||||
fmt::{Debug, Display, Formatter, Result as FmtResult},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
/// Container for 256-bit hash values.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Default)]
|
||||
@ -23,20 +24,6 @@ impl Hash {
|
||||
}
|
||||
}
|
||||
|
||||
/// Debug print hashes using `hex!(..)` literals.
|
||||
impl Debug for Hash {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
write!(f, "Hash(hex!(\"{}\"))", hex::encode(&self.0))
|
||||
}
|
||||
}
|
||||
|
||||
/// Display print hashes as `0x...`.
|
||||
impl Display for Hash {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
write!(f, "0x{}", hex::encode(&self.0))
|
||||
}
|
||||
}
|
||||
|
||||
/// Conversion from Ether U256
|
||||
impl From<&Hash> for U256 {
|
||||
fn from(hash: &Hash) -> Self {
|
||||
@ -75,6 +62,22 @@ impl From<&Hash> for BigInt {
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Hash {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let hex = bytes_to_hex::<32, 66>(&self.0);
|
||||
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
|
||||
write!(f, "Field({})", hex_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Hash {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let hex = bytes_to_hex::<32, 66>(&self.0);
|
||||
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
|
||||
write!(f, "{}", hex_str)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse Hash from hex string.
|
||||
/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix
|
||||
/// but they must always be exactly 32 bytes.
|
||||
|
46
src/lib.rs
46
src/lib.rs
@ -39,6 +39,7 @@ mod test {
|
||||
protocol::{generate_nullifier_hash, generate_proof, verify_proof},
|
||||
Field,
|
||||
};
|
||||
use std::thread::spawn;
|
||||
|
||||
#[test]
|
||||
fn test_field_serde() {
|
||||
@ -48,15 +49,14 @@ mod test {
|
||||
assert_eq!(value, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_end_to_end() {
|
||||
fn test_end_to_end(identity: &[u8], external_nullifier: &[u8], signal: &[u8]) {
|
||||
// const LEAF: Hash = Hash::from_bytes_be(hex!(
|
||||
// "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
// ));
|
||||
let leaf = Field::from(0);
|
||||
|
||||
// generate identity
|
||||
let id = Identity::from_seed(b"hello");
|
||||
let id = Identity::from_seed(identity);
|
||||
|
||||
// generate merkle tree
|
||||
let mut tree = PoseidonTree::new(21, leaf);
|
||||
@ -64,10 +64,7 @@ mod test {
|
||||
|
||||
let merkle_proof = tree.proof(0).expect("proof should exist");
|
||||
let root = tree.root();
|
||||
|
||||
// change signal and external_nullifier here
|
||||
let signal = b"xxx";
|
||||
let external_nullifier = b"appId";
|
||||
dbg!(root);
|
||||
|
||||
let signal_hash = hash_to_field(signal);
|
||||
let external_nullifier_hash = hash_to_field(external_nullifier);
|
||||
@ -76,16 +73,33 @@ mod test {
|
||||
let proof =
|
||||
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap();
|
||||
|
||||
let success = verify_proof(
|
||||
root,
|
||||
nullifier_hash,
|
||||
signal_hash,
|
||||
external_nullifier_hash,
|
||||
&proof,
|
||||
)
|
||||
.unwrap();
|
||||
for _ in 0..5 {
|
||||
let success = verify_proof(
|
||||
root,
|
||||
nullifier_hash,
|
||||
signal_hash,
|
||||
external_nullifier_hash,
|
||||
&proof,
|
||||
)
|
||||
.unwrap();
|
||||
assert!(success);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn test_single() {
|
||||
// Note that rust will still run tests in parallel
|
||||
test_end_to_end(b"hello", b"appId", b"xxx");
|
||||
}
|
||||
|
||||
assert!(success);
|
||||
#[test]
|
||||
fn test_parallel() {
|
||||
// Note that this does not guarantee a concurrency issue will be detected.
|
||||
// For that we need much more sophisticated static analysis tooling like
|
||||
// loom. See <https://github.com/tokio-rs/loom>
|
||||
let a = spawn(|| test_end_to_end(b"hello", b"appId", b"xxx"));
|
||||
let b = spawn(|| test_end_to_end(b"secret", b"test", b"signal"));
|
||||
a.join().unwrap();
|
||||
b.join().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,8 +116,10 @@ pub fn generate_proof(
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let full_assignment = WITNESS_CALCULATOR
|
||||
.clone()
|
||||
let full_assignment =
|
||||
WITNESS_CALCULATOR
|
||||
.lock()
|
||||
.expect("witness_calculator mutex should not get poisoned")
|
||||
.calculate_witness_element::<Bn254, _>(inputs, false)
|
||||
.map_err(ProofError::WitnessError)?;
|
||||
|
||||
@ -178,8 +180,7 @@ mod test {
|
||||
use super::*;
|
||||
use crate::{hash_to_field, poseidon_tree::PoseidonTree};
|
||||
|
||||
#[test]
|
||||
fn test_proof_serialize() {
|
||||
fn arb_proof() -> Proof {
|
||||
// generate identity
|
||||
let id = Identity::from_seed(b"secret");
|
||||
|
||||
@ -194,9 +195,20 @@ mod test {
|
||||
let signal_hash = hash_to_field(b"xxx");
|
||||
let external_nullifier_hash = hash_to_field(b"appId");
|
||||
|
||||
let proof =
|
||||
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap();
|
||||
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proof_cast_roundtrip() {
|
||||
let proof = arb_proof();
|
||||
let ark_proof: ArkProof<Bn<Parameters>> = proof.into();
|
||||
let result: Proof = ark_proof.into();
|
||||
assert_eq!(proof, result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proof_serialize() {
|
||||
let proof = arb_proof();
|
||||
let _json = serde_json::to_value(&proof).unwrap();
|
||||
|
||||
// TODO: Ideally we would check the output against an expected value,
|
||||
|
15
src/util.rs
15
src/util.rs
@ -16,6 +16,16 @@ pub(crate) fn keccak256(bytes: &[u8]) -> [u8; 32] {
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_to_hex<const N: usize, const M: usize>(bytes: &[u8; N]) -> [u8; M] {
|
||||
// TODO: Replace `M` with a const expression once it's stable.
|
||||
debug_assert_eq!(M, 2 * N + 2);
|
||||
let mut result = [0u8; M];
|
||||
result[0] = b'0';
|
||||
result[1] = b'x';
|
||||
hex::encode_to_slice(&bytes[..], &mut result[2..]).expect("the buffer is correctly sized");
|
||||
result
|
||||
}
|
||||
|
||||
/// Helper to serialize byte arrays
|
||||
pub(crate) fn serialize_bytes<const N: usize, const M: usize, S: Serializer>(
|
||||
serializer: S,
|
||||
@ -25,10 +35,7 @@ pub(crate) fn serialize_bytes<const N: usize, const M: usize, S: Serializer>(
|
||||
debug_assert_eq!(M, 2 * N + 2);
|
||||
if serializer.is_human_readable() {
|
||||
// Write as a 0x prefixed lower-case hex string
|
||||
let mut buffer = [0u8; M];
|
||||
buffer[0] = b'0';
|
||||
buffer[1] = b'x';
|
||||
hex::encode_to_slice(&bytes[..], &mut buffer[2..]).expect("the buffer is correctly sized");
|
||||
let buffer = bytes_to_hex::<N, M>(bytes);
|
||||
let string = str::from_utf8(&buffer).expect("the buffer is valid UTF-8");
|
||||
serializer.serialize_str(string)
|
||||
} else {
|
||||
|
Loading…
x
Reference in New Issue
Block a user