From b05238f09a3b66ed5dc829d288fea1eb9d39570b Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Thu, 6 Mar 2025 21:27:16 -0600 Subject: [PATCH] wip --- Cargo.toml | 4 ++ src/ffi.rs | 177 ++++++++++++++++++++++++++++++++++------------- src/ffi_types.rs | 66 +++++++++++++----- 3 files changed, 178 insertions(+), 69 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9557d1d..57f20ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,10 @@ opt-level = 3 # Use slightly better optimizations. debug = true # Generate debug info. debug-assertions = true # Enable debug assertions. +[profile.release] +opt-level = 3 # Maximum optimization level +debug = true + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] crate-type = [ diff --git a/src/ffi.rs b/src/ffi.rs index bd25f86..92570ce 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -3,7 +3,6 @@ use std::{ ffi::{c_char, CStr}, fs::File, panic::{catch_unwind, AssertUnwindSafe}, - ptr::slice_from_raw_parts_mut, }; use crate::ffi_types::*; @@ -13,6 +12,7 @@ use ark_crypto_primitives::snark::SNARK; use ark_groth16::{prepare_verifying_key, Groth16, ProvingKey}; use ark_std::rand::thread_rng; use ruint::aliases::U256; +use std::sync::{Arc, Mutex}; type GrothBn = Groth16; @@ -35,21 +35,18 @@ pub const ERR_SERIALIZE_INPUTS: i32 = 14; #[derive(Debug, Clone)] struct CircomBn254Cfg { - cfg: *mut CircomConfig, - proving_key: *mut ProvingKey, - _marker: core::marker::PhantomData<(*mut CircomBn254Cfg, core::marker::PhantomPinned)>, + cfg: Arc>>, + proving_key: Arc>>, } #[derive(Debug, Clone)] struct CircomBn254 { - builder: *mut CircomBuilder, - _marker: core::marker::PhantomData<(*mut CircomBn254, core::marker::PhantomPinned)>, + builder: Arc>>, } #[derive(Debug, Clone)] struct CircomCompatCtx { - circom: *mut CircomBn254, - _marker: core::marker::PhantomData<(*mut CircomCompatCtx, core::marker::PhantomPinned)>, + circom: Arc>, } fn to_err_code(result: Result>) -> i32 { @@ -108,9 +105,8 @@ pub unsafe extern "C" fn init_circom_config_with_checks( }; let circom_bn254_cfg = CircomBn254Cfg { - cfg: Box::into_raw(Box::new(cfg)), - proving_key: Box::into_raw(Box::new(proving_key)), - _marker: std::marker::PhantomData, + cfg: Arc::new(Mutex::new(cfg)), + proving_key: Arc::new(Mutex::new(proving_key)), }; *cfg_ptr = Box::into_raw(Box::new(circom_bn254_cfg)); @@ -139,15 +135,13 @@ pub unsafe extern "C" fn init_circom_compat( ctx_ptr: &mut *mut CircomCompatCtx, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { - let builder = CircomBuilder::new((*(*cfg_ptr).cfg).clone()); // clone the config + let builder = CircomBuilder::new((*(*cfg_ptr).cfg.lock().unwrap()).clone()); // clone the config let circom_bn254 = CircomBn254 { - builder: Box::into_raw(Box::new(builder)), - _marker: core::marker::PhantomData, + builder: Arc::new(Mutex::new(builder)), }; let circom_compat_ctx = CircomCompatCtx { - circom: Box::into_raw(Box::new(circom_bn254)), - _marker: core::marker::PhantomData, + circom: Arc::new(Mutex::new(circom_bn254)), }; *ctx_ptr = Box::into_raw(Box::new(circom_compat_ctx)); @@ -162,14 +156,7 @@ pub unsafe extern "C" fn init_circom_compat( #[allow(private_interfaces)] pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCtx) { if !ctx_ptr.is_null() { - let ctx = &mut Box::from_raw(*ctx_ptr); - - if !ctx.circom.is_null() { - let circom = &mut Box::from_raw(ctx.circom); - let builder = Box::from_raw(circom.builder); - drop(builder); - } - + drop(Box::from_raw(*ctx_ptr)); *ctx_ptr = std::ptr::null_mut(); } } @@ -179,8 +166,6 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt pub unsafe extern "C" fn release_cfg(cfg_ptr: &mut *mut CircomBn254Cfg) { if !cfg_ptr.is_null() && !(*cfg_ptr).is_null() { let cfg = Box::from_raw(*cfg_ptr); - drop(Box::from_raw((*cfg).proving_key)); - drop(Box::from_raw((*cfg).cfg)); drop(cfg); *cfg_ptr = std::ptr::null_mut(); } @@ -199,11 +184,7 @@ pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) { pub unsafe extern "C" fn release_inputs(inputs_ptr: &mut *mut Inputs) { if !inputs_ptr.is_null() { let inputs = Box::from_raw(*inputs_ptr); - let elms = Box::from_raw(slice_from_raw_parts_mut( - inputs.elms as *mut [u8; 32], - inputs.len, - )); - drop(elms); + inputs.free(); drop(inputs); *inputs_ptr = std::ptr::null_mut(); } @@ -214,16 +195,15 @@ pub unsafe extern "C" fn release_inputs(inputs_ptr: &mut *mut Inputs) { pub unsafe extern "C" fn release_key(key_ptr: &mut *mut VerifyingKey) { if !key_ptr.is_null() { let key = Box::from_raw(*key_ptr); - let ic= Box::from_raw(slice_from_raw_parts_mut(key.ic as *mut G1, key.ic_len)); - drop(ic); + key.free(); drop(key); *key_ptr = std::ptr::null_mut(); } } -unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 { - (*ctx_ptr).circom as *mut CircomBn254 -} +// unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 { +// (*ctx_ptr).circom.lock().unwrap() as *mut CircomBn254 +// } #[no_mangle] #[allow(private_interfaces)] @@ -233,11 +213,11 @@ pub unsafe extern "C" fn prove_circuit( proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { - let circom = &mut *to_circom(ctx_ptr); - let proving_key = (*(*cfg_ptr).proving_key).clone(); + let circom = ((*ctx_ptr).circom.lock().unwrap()).clone(); + let proving_key = (*(*cfg_ptr).proving_key.lock().unwrap()).clone(); let mut rng = thread_rng(); - let circuit = (*circom.builder) + let circuit = (*circom.builder.lock().unwrap()) .clone() .build() .map_err(|_| ERR_CIRCOM_BUILDER) @@ -247,7 +227,7 @@ pub unsafe extern "C" fn prove_circuit( .map_err(|_| ERR_MAKING_PROOF) .unwrap(); - *proof_ptr = Box::leak(Box::new((&circom_proof).into())); + *proof_ptr = Box::into_raw(Box::new((&circom_proof).into())); ERR_OK })); @@ -262,8 +242,8 @@ pub unsafe extern "C" fn get_pub_inputs( inputs_ptr: &mut *mut Inputs, // inputs_bytes_ptr: &mut *mut Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { - let circom = &mut *to_circom(ctx_ptr); - let circuit = (*circom.builder) + let circom = (*ctx_ptr).circom.lock().unwrap(); + let circuit = (*circom.builder.lock().unwrap()) .clone() .build() .map_err(|_| ERR_CIRCOM_BUILDER) @@ -273,7 +253,7 @@ pub unsafe extern "C" fn get_pub_inputs( .get_public_inputs() .ok_or_else(|| ERR_GET_PUB_INPUTS) .unwrap(); - *inputs_ptr = Box::leak(Box::new(inputs.as_slice().into())); + *inputs_ptr = Box::into_raw(Box::new(inputs.as_slice().into())); ERR_OK })); @@ -289,7 +269,7 @@ pub unsafe extern "C" fn get_verifying_key( ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { let ctx = &mut *cfg_ptr; - let proving_key = &(*(*ctx).proving_key); + let proving_key = &(*(*ctx).proving_key.lock().unwrap()); let vk = prepare_verifying_key(&proving_key.vk).vk; *vk_ptr = Box::into_raw(Box::new((&vk).into())); @@ -308,7 +288,7 @@ pub unsafe extern "C" fn verify_circuit( pvk: *const VerifyingKey, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { - let inputs_vec: Vec = (*inputs).into(); + let inputs_vec: Vec = (*inputs).clone().into(); let prepared_key = prepare_verifying_key(&(*pvk).into()); let passed = GrothBn::verify_proof(&prepared_key, &(*proof).into(), inputs_vec.as_slice()) @@ -350,10 +330,10 @@ pub unsafe extern "C" fn push_input_u256_array( .map(|c| U256::try_from_le_slice(c).ok_or(ERR_INVALID_INPUT).unwrap()) .collect::>(); - let circom = &mut *to_circom(ctx_ptr); + let circom = (*ctx_ptr).circom.lock().unwrap(); inputs .iter() - .for_each(|c| (*circom.builder).push_input(name, *c)); + .for_each(|c| (*circom.builder.lock().unwrap()).push_input(name, *c)); ERR_OK })); @@ -375,8 +355,8 @@ macro_rules! build_fn let name = CStr::from_ptr(name_ptr).to_str().map_err(|_| ERR_INPUT_NAME).unwrap(); let input = U256::from(input); - let circom = &mut *to_circom(ctx_ptr); - (*circom.builder).push_input(name, input); + let circom = (*ctx_ptr).circom.lock().unwrap(); + (*circom.builder.lock().unwrap()).push_input(name, input); ERR_OK })); @@ -400,7 +380,6 @@ mod test { use std::ffi::CString; #[test] - #[ignore] // TODO: getting alignment issues for this circuit, need to investigate further fn proof_verify() { let r1cs_path = CString::new("./fixtures/circom2_multiplier2.r1cs".as_bytes()).unwrap(); let wasm_path = CString::new("./fixtures/circom2_multiplier2.wasm".as_bytes()).unwrap(); @@ -519,4 +498,102 @@ mod test { assert!(cfg_ptr == std::ptr::null_mut()); }; } + + // Wrapper to make raw pointers safely sendable between threads + // This is safe because we know the pointers are protected by Arc> internally + struct ThreadSafePointer(*const T); + unsafe impl Send for ThreadSafePointer {} + + #[test] + fn multithreaded_prove_and_verify() { + use std::sync::{Arc, Barrier}; + use std::thread; + + let r1cs_path = CString::new("./fixtures/circom2_multiplier2.r1cs".as_bytes()).unwrap(); + let wasm_path = CString::new("./fixtures/circom2_multiplier2.wasm".as_bytes()).unwrap(); + let zkey_path = CString::new("./fixtures/test.zkey".as_bytes()).unwrap(); + + unsafe { + // Setup in the main thread + let mut cfg_ptr: *mut CircomBn254Cfg = std::ptr::null_mut(); + init_circom_config( + r1cs_path.as_ptr(), + wasm_path.as_ptr(), + zkey_path.as_ptr(), + &mut cfg_ptr, + ); + assert!(cfg_ptr != std::ptr::null_mut()); + + let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut(); + init_circom_compat(cfg_ptr, &mut ctx_ptr); + assert!(ctx_ptr != std::ptr::null_mut()); + + // Push inputs + let a = CString::new("a".as_bytes()).unwrap(); + push_input_i8(ctx_ptr, a.as_ptr(), 3); + + let b = CString::new("b".as_bytes()).unwrap(); + push_input_i8(ctx_ptr, b.as_ptr(), 11); + + // Create a barrier to synchronize threads + let barrier = Arc::new(Barrier::new(2)); + let barrier_clone = barrier.clone(); + + // Wrap the pointers in a thread-safe container + let thread_cfg_ptr = Arc::new(Mutex::new(ThreadSafePointer( + cfg_ptr as *const CircomBn254Cfg, + ))); + let thread_ctx_ptr = Arc::new(Mutex::new(ThreadSafePointer( + ctx_ptr as *const CircomCompatCtx, + ))); + + let thread_cfg_ptr_clone = Arc::clone(&thread_cfg_ptr); + let thread_ctx_ptr_clone = Arc::clone(&thread_ctx_ptr); + + // Spawn thread for both proving and verification + let prover_thread = thread::spawn(move || { + // Unwrap the pointers + let thread_cfg_ptr = thread_cfg_ptr_clone.lock().unwrap().0 as *mut CircomBn254Cfg; + let thread_ctx_ptr = thread_ctx_ptr_clone.lock().unwrap().0 as *mut CircomCompatCtx; + + // Create pointers for proof, inputs, and verification key + let mut proof_ptr: *mut Proof = std::ptr::null_mut(); + let mut inputs_ptr: *mut Inputs = std::ptr::null_mut(); + let mut vk_ptr: *mut VerifyingKey = std::ptr::null_mut(); + + // Get public inputs + assert!(get_pub_inputs(thread_ctx_ptr, &mut inputs_ptr) == ERR_OK); + assert!(inputs_ptr != std::ptr::null_mut()); + + // Generate proof + assert!(prove_circuit(thread_cfg_ptr, thread_ctx_ptr, &mut proof_ptr) == ERR_OK); + assert!(proof_ptr != std::ptr::null_mut()); + + // Get verification key + assert!(get_verifying_key(thread_cfg_ptr, &mut vk_ptr) == ERR_OK); + assert!(vk_ptr != std::ptr::null_mut()); + + // Verify the proof in the same thread + assert!(verify_circuit(&(*proof_ptr), &(*inputs_ptr), &(*vk_ptr)) == ERR_OK); + + // Clean up resources in this thread + release_proof(&mut proof_ptr); + release_inputs(&mut inputs_ptr); + release_key(&mut vk_ptr); + + // Signal that we're done + barrier_clone.wait(); + }); + + // Wait for the prover thread to complete + barrier.wait(); + + // Clean up remaining resources in the main thread + release_circom_compat(&mut ctx_ptr); + release_cfg(&mut cfg_ptr); + + // Wait for the thread to finish + prover_thread.join().unwrap(); + }; + } } diff --git a/src/ffi_types.rs b/src/ffi_types.rs index 098f793..57beb57 100644 --- a/src/ffi_types.rs +++ b/src/ffi_types.rs @@ -1,4 +1,4 @@ -use std::ptr::slice_from_raw_parts; +use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut}; use ark_bn254::{Bn254, Fq, Fq2, Fr, G1Affine, G2Affine}; use ark_ff::{BigInteger, PrimeField}; @@ -6,7 +6,7 @@ use ark_serialize::CanonicalDeserialize; use ark_std::Zero; // Helper for converting a PrimeField to little endian byte slice -fn slice_to_point(point: &[u8; 32]) -> F { +fn slice_to_point(point: [u8; 32]) -> F { let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works"); F::from_bigint(bigint).unwrap() } @@ -90,8 +90,8 @@ impl From<&ark_groth16::Proof> for Proof { impl From for G1Affine { fn from(src: G1) -> Self { - let x: Fq = slice_to_point(&src.x); - let y: Fq = slice_to_point(&src.y); + let x: Fq = slice_to_point(src.x); + let y: Fq = slice_to_point(src.y); if x.is_zero() && y.is_zero() { G1Affine::identity() } else { @@ -102,12 +102,12 @@ impl From for G1Affine { impl From for G2Affine { fn from(src: G2) -> G2Affine { - let c0 = slice_to_point(&src.x[0]); - let c1 = slice_to_point(&src.x[1]); + let c0 = slice_to_point(src.x[0]); + let c1 = slice_to_point(src.x[1]); let x = Fq2::new(c0, c1); - let c0 = slice_to_point(&src.y[0]); - let c1 = slice_to_point(&src.y[1]); + let c0 = slice_to_point(src.y[0]); + let c1 = slice_to_point(src.y[1]); let y = Fq2::new(c0, c1); if x.is_zero() && y.is_zero() { @@ -136,7 +136,7 @@ impl From for ark_groth16::VerifyingKey { gamma_g2: src.gamma2.into(), delta_g2: src.delta2.into(), gamma_abc_g1: unsafe { - std::slice::from_raw_parts(src.ic, src.ic_len) + (*slice_from_raw_parts(src.ic, src.ic_len)) .iter() .map(|p| (*p).into()) .collect() @@ -145,10 +145,39 @@ impl From for ark_groth16::VerifyingKey { } } +impl VerifyingKey { + pub fn free(mut self) { + unsafe { + if !self.ic.is_null() && self.ic_len > 0 { + drop(Box::from_raw(slice_from_raw_parts_mut( + self.ic as *mut G1, + self.ic_len, + ))); + self.ic = std::ptr::null(); + self.ic_len = 0; + } + } + } +} + +impl Inputs { + pub fn free(mut self) { + unsafe { + if !self.elms.is_null() && self.len > 0 { + drop(Box::from_raw(slice_from_raw_parts_mut( + self.elms as *mut [u8; 32], + self.len, + ))); + self.elms = std::ptr::null(); + self.len = 0; + } + } + } +} + impl From<&ark_groth16::VerifyingKey> for VerifyingKey { fn from(vk: &ark_groth16::VerifyingKey) -> Self { - let mut ic: Vec = vk.gamma_abc_g1.iter().map(|p| p.into()).collect(); - ic.shrink_to_fit(); + let ic: Vec = vk.gamma_abc_g1.iter().map(|p| p.into()).collect(); let len = ic.len(); Self { @@ -164,12 +193,10 @@ impl From<&ark_groth16::VerifyingKey> for VerifyingKey { impl From<&[Fr]> for Inputs { fn from(src: &[Fr]) -> Self { - let mut els: Vec<[u8; 32]> = src.iter().map(|point| point_to_slice(*point)).collect(); - - els.shrink_to_fit(); + let els: Vec<[u8; 32]> = src.iter().map(|point| point_to_slice(*point)).collect(); let len = els.len(); Self { - elms: Box::leak(els.into_boxed_slice()).as_ptr(), + elms: Box::into_raw(els.into_boxed_slice()) as *const [u8; 32], len: len, } } @@ -178,9 +205,9 @@ impl From<&[Fr]> for Inputs { impl From for Vec { fn from(src: Inputs) -> Self { let els: Vec = unsafe { - (&*slice_from_raw_parts(src.elms, src.len)) + (*slice_from_raw_parts(src.elms, src.len)) .iter() - .map(|point| slice_to_point(point)) + .map(|point| slice_to_point(*point)) .collect() }; @@ -216,7 +243,7 @@ mod test { fn convert_fq() { let el = fq(); let el2 = point_to_slice(el); - let el3: Fq = slice_to_point(&el2); + let el3: Fq = slice_to_point(el2); let el4 = point_to_slice(el3); assert_eq!(el, el3); assert_eq!(el2, el4); @@ -226,7 +253,7 @@ mod test { fn convert_fr() { let el = fr(); let el2 = point_to_slice(el); - let el3: Fr = slice_to_point(&el2); + let el3: Fr = slice_to_point(el2); let el4 = point_to_slice(el3); assert_eq!(el, el3); assert_eq!(el2, el4); @@ -264,6 +291,7 @@ mod test { let vk_ffi = &VerifyingKey::from(&vk); let ark_vk: ark_groth16::VerifyingKey = (*vk_ffi).into(); assert_eq!(ark_vk, vk); + vk_ffi.free(); } #[test]