diff --git a/src/ffi.rs b/src/ffi.rs index 24c8e9b..4938ed6 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -3,7 +3,7 @@ use std::{ ffi::{c_char, CStr}, fs::File, os::raw::c_void, - panic::{catch_unwind, AssertUnwindSafe}, + panic::{catch_unwind, AssertUnwindSafe}, io::Cursor, }; use ark_bn254::{Bn254, Fr}; @@ -28,6 +28,10 @@ pub const ERR_CANT_READ_ZKEY: i32 = 6; pub const ERR_CIRCOM_BUILDER: i32 = 7; pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8; pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9; +pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10; +pub const ERR_GET_PUB_INPUTS: i32 = 11; +pub const ERR_MAKING_PROOF: i32 = 12; +pub const ERR_SERIALIZE_PROOF: i32 = 13; #[derive(Debug, Clone)] #[repr(C)] @@ -150,6 +154,7 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt } #[no_mangle] +// Only use if the buffer was allocated by the ffi pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) { if !buff_ptr.is_null() { let buff = &mut Box::from_raw(*buff_ptr); @@ -164,39 +169,6 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 { (*ctx_ptr).circom as *mut CircomBn254 } -/// # Safety -/// -#[no_mangle] -#[allow(private_interfaces)] -pub unsafe extern "C" fn push_input_u256_array( - ctx_ptr: *mut CircomCompatCtx, - name_ptr: *const c_char, - input_ptr: *const u8, - len: usize, -) -> i32 { - let result = catch_unwind(AssertUnwindSafe(|| { - let name = CStr::from_ptr(name_ptr) - .to_str() - .map_err(|_| ERR_INPUT_NAME) - .unwrap(); - - let input = { - let slice = std::slice::from_raw_parts(input_ptr, len); - slice - .chunks(U256::BYTES) - .map(|c| U256::try_from_le_slice(c).ok_or(ERR_INVALID_INPUT).unwrap()) - .collect::>() - }; - - let circom = &mut *to_circom(ctx_ptr); - input - .iter() - .for_each(|c| (*circom.builder).push_input(name, *c)); - })); - - to_err_code(result) -} - /// # Safety /// #[no_mangle] @@ -212,10 +184,19 @@ pub unsafe extern "C" fn prove_circuit( let proving_key = &(*circom.proving_key); let rng = &mut (*ctx_ptr).rng; - let circuit = (*circom.builder).clone().build().unwrap(); + let circuit = (*circom.builder) + .clone() + .build() + .map_err(|_| ERR_CIRCOM_BUILDER) + .unwrap(); - let inputs = circuit.get_public_inputs().unwrap(); - let proof = GrothBn::prove(&proving_key, circuit, rng).unwrap(); + let inputs = circuit + .get_public_inputs() + .ok_or_else(|| ERR_GET_PUB_INPUTS) + .unwrap(); + let proof = GrothBn::prove(&proving_key, circuit, rng) + .map_err(|_| ERR_MAKING_PROOF) + .unwrap(); let mut proof_bytes = Vec::new(); proof.serialize_compressed(&mut proof_bytes).unwrap(); @@ -223,6 +204,7 @@ pub unsafe extern "C" fn prove_circuit( let mut public_inputs_bytes = Vec::new(); inputs .serialize_compressed(&mut public_inputs_bytes) + .map_err(|_| ERR_SERIALIZE_PROOF) .unwrap(); // leak the buffers to avoid rust from freeing the pointed to data, @@ -278,13 +260,46 @@ pub unsafe extern "C" fn verify_circuit( let pvk = prepare_verifying_key(&proving_key.vk); GrothBn::verify_proof(&pvk, &proof, &public_inputs) - .map_err(|e| e.to_string()) + .map_err(|_| ERR_FAILED_TO_VERIFY_PROOF) .unwrap(); })); to_err_code(result) } +/// # Safety +/// +#[no_mangle] +#[allow(private_interfaces)] +pub unsafe extern "C" fn push_input_u256_array( + ctx_ptr: *mut CircomCompatCtx, + name_ptr: *const c_char, + input_ptr: *const u8, + len: usize, +) -> i32 { + let result = catch_unwind(AssertUnwindSafe(|| { + let name = CStr::from_ptr(name_ptr) + .to_str() + .map_err(|_| ERR_INPUT_NAME) + .unwrap(); + + let input = { + let slice = std::slice::from_raw_parts(input_ptr, len); + slice + .chunks(U256::BYTES) + .map(|c| U256::try_from_le_slice(c).ok_or(ERR_INVALID_INPUT).unwrap()) + .collect::>() + }; + + let circom = &mut *to_circom(ctx_ptr); + input + .iter() + .for_each(|c| (*circom.builder).push_input(name, *c)); + })); + + to_err_code(result) +} + macro_rules! build_fn { ($name:tt, $($v:ident: $t:ty),*) => { @@ -296,7 +311,7 @@ macro_rules! build_fn input: $($t),* ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { - let name = CStr::from_ptr(name_ptr).to_str().unwrap(); + let name = CStr::from_ptr(name_ptr).to_str().map_err(|_| ERR_INPUT_NAME).unwrap(); let input = U256::from(input); let circom = &mut *to_circom(ctx_ptr); @@ -323,7 +338,7 @@ mod test { use super::*; #[test] - fn groth16_proof() { + fn proof_verify() { let r1cs_path = CString::new("./fixtures/mycircuit.r1cs".as_bytes()).unwrap(); let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap();