Merge pull request #10 from worldcoin/remco/concurrent-proof

Work around concurrent witness calculator bug
This commit is contained in:
Philipp Sippl 2022-03-21 12:21:52 +01:00 committed by GitHub
commit d861a73645
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 107 additions and 50 deletions

View File

@ -6,6 +6,7 @@ use core::include_bytes;
use once_cell::sync::Lazy;
use std::io::{Cursor, Write};
use tempfile::NamedTempFile;
use std::sync::Mutex;
const ZKEY_BYTES: &[u8] = include_bytes!("../semaphore/build/snark/semaphore_final.zkey");
const WASM: &[u8] = include_bytes!("../semaphore/build/snark/semaphore.wasm");
@ -15,7 +16,7 @@ pub static ZKEY: Lazy<(ProvingKey<Bn254>, ConstraintMatrices<Fr>)> = Lazy::new(|
read_zkey(&mut reader).expect("zkey should be valid")
});
pub static WITNESS_CALCULATOR: Lazy<WitnessCalculator> = Lazy::new(|| {
pub static WITNESS_CALCULATOR: Lazy<Mutex<WitnessCalculator>> = Lazy::new(|| {
// HACK: ark-circom requires a file, so we make one!
let mut tmpfile = NamedTempFile::new().expect("Failed to create temp file");
let written = tmpfile.write(WASM).expect("Failed to write to temp file");
@ -23,5 +24,5 @@ pub static WITNESS_CALCULATOR: Lazy<WitnessCalculator> = Lazy::new(|| {
let path = tmpfile.into_temp_path();
let result = WitnessCalculator::new(&path).expect("Failed to create witness calculator");
path.close().expect("Could not remove tempfile");
result
Mutex::new(result)
});

View File

@ -1,7 +1,11 @@
use crate::util::{bytes_from_hex, deserialize_bytes, keccak256, serialize_bytes};
use crate::util::{bytes_from_hex, bytes_to_hex, deserialize_bytes, keccak256, serialize_bytes};
use ark_bn254::Fr as ArkField;
use ark_ff::{BigInteger as _, PrimeField as _};
use core::{str, str::FromStr};
use core::{
fmt::{Debug, Display},
str,
str::FromStr,
};
use ff::{PrimeField as _, PrimeFieldRepr as _};
use num_bigint::{BigInt, Sign};
use poseidon_rs::Fr as PosField;
@ -10,7 +14,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// An element of the BN254 scalar field Fr.
///
/// Represented as a big-endian byte vector without Montgomery reduction.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
// TODO: Make sure value is always reduced.
pub struct Field([u8; 32]);
@ -69,6 +73,22 @@ impl From<Field> for BigInt {
}
}
impl Debug for Field {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let hex = bytes_to_hex::<32, 66>(&self.0);
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
write!(f, "Field({})", hex_str)
}
}
impl Display for Field {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let hex = bytes_to_hex::<32, 66>(&self.0);
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
write!(f, "{}", hex_str)
}
}
/// Serialize a field element.
///
/// For human readable formats a `0x` prefixed lower case hex string is used.

View File

@ -1,11 +1,12 @@
use crate::util::{bytes_from_hex, deserialize_bytes, serialize_bytes};
use crate::util::{bytes_from_hex, bytes_to_hex, deserialize_bytes, serialize_bytes};
use core::{
fmt::{Debug, Display},
str,
str::FromStr,
};
use ethers_core::types::U256;
use num_bigint::{BigInt, Sign};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{
fmt::{Debug, Display, Formatter, Result as FmtResult},
str::FromStr,
};
/// Container for 256-bit hash values.
#[derive(Clone, Copy, PartialEq, Eq, Default)]
@ -23,20 +24,6 @@ impl Hash {
}
}
/// Debug print hashes using `hex!(..)` literals.
impl Debug for Hash {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "Hash(hex!(\"{}\"))", hex::encode(&self.0))
}
}
/// Display print hashes as `0x...`.
impl Display for Hash {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "0x{}", hex::encode(&self.0))
}
}
/// Conversion from Ether U256
impl From<&Hash> for U256 {
fn from(hash: &Hash) -> Self {
@ -75,6 +62,22 @@ impl From<&Hash> for BigInt {
}
}
impl Debug for Hash {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let hex = bytes_to_hex::<32, 66>(&self.0);
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
write!(f, "Field({})", hex_str)
}
}
impl Display for Hash {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let hex = bytes_to_hex::<32, 66>(&self.0);
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
write!(f, "{}", hex_str)
}
}
/// Parse Hash from hex string.
/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix
/// but they must always be exactly 32 bytes.

View File

@ -39,6 +39,7 @@ mod test {
protocol::{generate_nullifier_hash, generate_proof, verify_proof},
Field,
};
use std::thread::spawn;
#[test]
fn test_field_serde() {
@ -48,15 +49,14 @@ mod test {
assert_eq!(value, deserialized);
}
#[test]
fn test_end_to_end() {
fn test_end_to_end(identity: &[u8], external_nullifier: &[u8], signal: &[u8]) {
// const LEAF: Hash = Hash::from_bytes_be(hex!(
// "0000000000000000000000000000000000000000000000000000000000000000"
// ));
let leaf = Field::from(0);
// generate identity
let id = Identity::from_seed(b"hello");
let id = Identity::from_seed(identity);
// generate merkle tree
let mut tree = PoseidonTree::new(21, leaf);
@ -64,10 +64,7 @@ mod test {
let merkle_proof = tree.proof(0).expect("proof should exist");
let root = tree.root();
// change signal and external_nullifier here
let signal = b"xxx";
let external_nullifier = b"appId";
dbg!(root);
let signal_hash = hash_to_field(signal);
let external_nullifier_hash = hash_to_field(external_nullifier);
@ -76,16 +73,33 @@ mod test {
let proof =
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap();
let success = verify_proof(
root,
nullifier_hash,
signal_hash,
external_nullifier_hash,
&proof,
)
.unwrap();
for _ in 0..5 {
let success = verify_proof(
root,
nullifier_hash,
signal_hash,
external_nullifier_hash,
&proof,
)
.unwrap();
assert!(success);
}
}
#[test]
fn test_single() {
// Note that rust will still run tests in parallel
test_end_to_end(b"hello", b"appId", b"xxx");
}
assert!(success);
#[test]
fn test_parallel() {
// Note that this does not guarantee a concurrency issue will be detected.
// For that we need much more sophisticated static analysis tooling like
// loom. See <https://github.com/tokio-rs/loom>
let a = spawn(|| test_end_to_end(b"hello", b"appId", b"xxx"));
let b = spawn(|| test_end_to_end(b"secret", b"test", b"signal"));
a.join().unwrap();
b.join().unwrap();
}
}

View File

@ -116,8 +116,10 @@ pub fn generate_proof(
let now = Instant::now();
let full_assignment = WITNESS_CALCULATOR
.clone()
let full_assignment =
WITNESS_CALCULATOR
.lock()
.expect("witness_calculator mutex should not get poisoned")
.calculate_witness_element::<Bn254, _>(inputs, false)
.map_err(ProofError::WitnessError)?;
@ -178,8 +180,7 @@ mod test {
use super::*;
use crate::{hash_to_field, poseidon_tree::PoseidonTree};
#[test]
fn test_proof_serialize() {
fn arb_proof() -> Proof {
// generate identity
let id = Identity::from_seed(b"secret");
@ -194,9 +195,20 @@ mod test {
let signal_hash = hash_to_field(b"xxx");
let external_nullifier_hash = hash_to_field(b"appId");
let proof =
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap();
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap()
}
#[test]
fn test_proof_cast_roundtrip() {
let proof = arb_proof();
let ark_proof: ArkProof<Bn<Parameters>> = proof.into();
let result: Proof = ark_proof.into();
assert_eq!(proof, result);
}
#[test]
fn test_proof_serialize() {
let proof = arb_proof();
let _json = serde_json::to_value(&proof).unwrap();
// TODO: Ideally we would check the output against an expected value,

View File

@ -16,6 +16,16 @@ pub(crate) fn keccak256(bytes: &[u8]) -> [u8; 32] {
output
}
pub(crate) fn bytes_to_hex<const N: usize, const M: usize>(bytes: &[u8; N]) -> [u8; M] {
// TODO: Replace `M` with a const expression once it's stable.
debug_assert_eq!(M, 2 * N + 2);
let mut result = [0u8; M];
result[0] = b'0';
result[1] = b'x';
hex::encode_to_slice(&bytes[..], &mut result[2..]).expect("the buffer is correctly sized");
result
}
/// Helper to serialize byte arrays
pub(crate) fn serialize_bytes<const N: usize, const M: usize, S: Serializer>(
serializer: S,
@ -25,10 +35,7 @@ pub(crate) fn serialize_bytes<const N: usize, const M: usize, S: Serializer>(
debug_assert_eq!(M, 2 * N + 2);
if serializer.is_human_readable() {
// Write as a 0x prefixed lower-case hex string
let mut buffer = [0u8; M];
buffer[0] = b'0';
buffer[1] = b'x';
hex::encode_to_slice(&bytes[..], &mut buffer[2..]).expect("the buffer is correctly sized");
let buffer = bytes_to_hex::<N, M>(bytes);
let string = str::from_utf8(&buffer).expect("the buffer is valid UTF-8");
serializer.serialize_str(string)
} else {