From 6028d1e0fc9d2cea9f4a0d7a6a55eca13d9f1017 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Thu, 25 Jan 2024 15:27:31 -0600 Subject: [PATCH] add release methods --- src/ffi.rs | 150 +++++++++++++++++++++++++++++++++-------------- src/ffi_types.rs | 8 +-- 2 files changed, 109 insertions(+), 49 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index 816de13..7e8d7a1 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -1,16 +1,11 @@ use std::{ - any::Any, - ffi::{c_char, CStr}, - fs::File, - os::raw::c_void, - panic::{catch_unwind, AssertUnwindSafe}, + any::Any, ffi::{c_char, CStr}, fs::File, os::raw::c_void, panic::{catch_unwind, AssertUnwindSafe}, ptr::slice_from_raw_parts_mut }; use ark_bn254::{Bn254, Fr}; use ark_circom::{read_zkey, CircomBuilder, CircomConfig}; use ark_crypto_primitives::snark::SNARK; -use ark_groth16::{prepare_verifying_key, Groth16, Proof as Groth16Proof, ProvingKey}; -use ark_serialize::Compress; +use ark_groth16::{prepare_verifying_key, Groth16, ProvingKey}; use ark_std::rand::{rngs::ThreadRng, thread_rng}; use ruint::aliases::U256; @@ -135,11 +130,45 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt // Only use if the buffer was allocated by the ffi pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) { if !buff_ptr.is_null() { - let buff = &mut Box::from_raw(*buff_ptr); - let _ = Box::from_raw(buff.data as *mut u8); - buff.data = std::ptr::null_mut(); - buff.len = 0; - *buff_ptr = std::ptr::null_mut(); + let buff = Box::from_raw(*buff_ptr); + let data = Box::from_raw(slice_from_raw_parts_mut(buff.data as *mut u8, buff.len)); + drop(data); + drop(buff); + } +} + +#[no_mangle] +pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) { + if !proof_ptr.is_null() { + drop(Box::from_raw(*proof_ptr)); + *proof_ptr = std::ptr::null_mut(); + } +} + +#[no_mangle] +// Only use if the buffer was allocated by the ffi +pub unsafe extern "C" fn release_inputs(inputs_ptr: &mut *mut Inputs) { + if !inputs_ptr.is_null() { + let inputs = Box::from_raw(*inputs_ptr); + let elms = Box::from_raw(slice_from_raw_parts_mut( + inputs.elms as *mut [u8; 32], + inputs.len, + )); + drop(elms); + drop(inputs); + *inputs_ptr = std::ptr::null_mut(); + } +} + +#[no_mangle] +// Only use if the buffer was allocated by the ffi +pub unsafe extern "C" fn release_key(key_ptr: &mut *mut VerifyingKey) { + if !key_ptr.is_null() { + let key = Box::from_raw(*key_ptr); + let ic: Box<[G1]> = Box::from_raw(slice_from_raw_parts_mut(key.ic as *mut G1, key.ic_len)); + drop(ic); + drop(key); + *key_ptr = std::ptr::null_mut(); } } @@ -153,14 +182,39 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 { #[allow(private_interfaces)] pub unsafe extern "C" fn prove_circuit( ctx_ptr: *mut CircomCompatCtx, - proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer, - inputs_ptr: &mut *mut Inputs, // inputs_bytes_ptr: &mut *mut Buffer, + proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { let circom = &mut *to_circom(ctx_ptr); let proving_key = &(*circom.proving_key); let rng = &mut (*ctx_ptr).rng; + let circuit = (*circom.builder) + .clone() + .build() + .map_err(|_| ERR_CIRCOM_BUILDER) + .unwrap(); + + let circom_proof = GrothBn::prove(proving_key, circuit, rng) + .map_err(|_| ERR_MAKING_PROOF) + .unwrap(); + + *proof_ptr = Box::leak(Box::new((&circom_proof).into())); + })); + + to_err_code(result) +} + +/// # Safety +/// +#[no_mangle] +#[allow(private_interfaces)] +pub unsafe extern "C" fn get_pub_inputs( + ctx_ptr: *mut CircomCompatCtx, + inputs_ptr: &mut *mut Inputs, // inputs_bytes_ptr: &mut *mut Buffer, +) -> i32 { + let result = catch_unwind(AssertUnwindSafe(|| { + let circom = &mut *to_circom(ctx_ptr); let circuit = (*circom.builder) .clone() .build() @@ -171,12 +225,6 @@ pub unsafe extern "C" fn prove_circuit( .get_public_inputs() .ok_or_else(|| ERR_GET_PUB_INPUTS) .unwrap(); - - let circomProof = GrothBn::prove(proving_key, circuit, rng) - .map_err(|_| ERR_MAKING_PROOF) - .unwrap(); - - *proof_ptr = Box::leak(Box::new((&circomProof).into())); *inputs_ptr = Box::leak(Box::new(inputs.as_slice().into())); })); @@ -187,24 +235,34 @@ pub unsafe extern "C" fn prove_circuit( /// #[no_mangle] #[allow(private_interfaces)] -pub unsafe extern "C" fn verify_circuit( +pub unsafe extern "C" fn get_verifying_key( ctx_ptr: *mut CircomCompatCtx, - compress: bool, - proof: *const Proof, - inputs: *const Inputs, + vk_ptr: &mut *mut VerifyingKey, // inputs_bytes_ptr: &mut *mut Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { - let mode = match compress { - true => Compress::Yes, - false => Compress::No, - }; - let circom = &mut *to_circom(ctx_ptr); let proving_key = &(*circom.proving_key); - let pvk = prepare_verifying_key(&proving_key.vk); + let vk = prepare_verifying_key(&proving_key.vk).vk; + *vk_ptr = Box::leak(Box::new((&vk).into())); + })); + + to_err_code(result) +} + +/// # Safety +/// +#[no_mangle] +#[allow(private_interfaces)] +pub unsafe extern "C" fn verify_circuit( + proof: *const Proof, + inputs: *const Inputs, + pvk: *const VerifyingKey, +) -> i32 { + let result = catch_unwind(AssertUnwindSafe(|| { let inputs_vec: Vec = (*inputs).into(); - GrothBn::verify_proof(&pvk, &(*proof).into(), inputs_vec.as_slice()) + let prepared_key = prepare_verifying_key(&(*pvk).into()); + GrothBn::verify_proof(&prepared_key, &(*proof).into(), inputs_vec.as_slice()) .map_err(|_| ERR_FAILED_TO_VERIFY_PROOF) .unwrap(); })); @@ -278,8 +336,8 @@ build_fn!(push_input_u64, x: u64); #[cfg(test)] mod test { - use std::ffi::CString; use super::*; + use std::ffi::CString; #[test] fn proof_verify() { @@ -305,23 +363,27 @@ mod test { let mut proof_ptr: *mut Proof = std::ptr::null_mut(); let mut inputs_ptr: *mut Inputs = std::ptr::null_mut(); + let mut vk_ptr: *mut VerifyingKey = std::ptr::null_mut(); - assert!(prove_circuit(ctx_ptr, &mut proof_ptr, &mut inputs_ptr) == ERR_OK); - - assert!(proof_ptr != std::ptr::null_mut()); + assert!(get_pub_inputs(ctx_ptr, &mut inputs_ptr) == ERR_OK); assert!(inputs_ptr != std::ptr::null_mut()); - assert!( - verify_circuit(ctx_ptr, true, &(*proof_ptr), &(*inputs_ptr)) == ERR_OK - ); + assert!(prove_circuit(ctx_ptr, &mut proof_ptr) == ERR_OK); + assert!(proof_ptr != std::ptr::null_mut()); - // release_buffer(&mut proof_bytes_ptr); - // release_buffer(&mut inputs_bytes_ptr); - // release_circom_compat(&mut ctx_ptr); + assert!(get_verifying_key(ctx_ptr, &mut vk_ptr) == ERR_OK); + assert!(vk_ptr != std::ptr::null_mut()); - // assert!(ctx_ptr == std::ptr::null_mut()); - // assert!(proof_bytes_ptr == std::ptr::null_mut()); - // assert!(inputs_bytes_ptr == std::ptr::null_mut()); + assert!(verify_circuit(&(*proof_ptr), &(*inputs_ptr), &(*vk_ptr)) == ERR_OK); + + release_inputs(&mut inputs_ptr); + assert!(inputs_ptr == std::ptr::null_mut()); + + release_proof(&mut proof_ptr); + assert!(proof_ptr == std::ptr::null_mut()); + + release_key(&mut vk_ptr); + assert!(vk_ptr == std::ptr::null_mut()); }; } } diff --git a/src/ffi_types.rs b/src/ffi_types.rs index 787513c..7456e9a 100644 --- a/src/ffi_types.rs +++ b/src/ffi_types.rs @@ -2,13 +2,10 @@ use std::ptr::slice_from_raw_parts; use ark_bn254::{Bn254, Fq, Fq2, Fr, G1Affine, G2Affine}; use ark_ff::{BigInteger, PrimeField}; -use ark_groth16::{Groth16, Proof as Groth16Proof}; use ark_serialize::CanonicalDeserialize; use ark_std::Zero; use num_bigint::BigUint; -type GrothBn = Groth16; - pub const ERR_UNKNOWN: i32 = -1; pub const ERR_OK: i32 = 0; pub const ERR_WASM_PATH: i32 = 1; @@ -106,8 +103,8 @@ impl From<&G2Affine> for G2 { } } -impl From<&Groth16Proof> for Proof { - fn from(src: &Groth16Proof) -> Self { +impl From<&ark_groth16::Proof> for Proof { + fn from(src: &ark_groth16::Proof) -> Self { Self { a: (&src.a).into(), b: (&src.b).into(), @@ -219,6 +216,7 @@ impl From for Vec { } } +#[cfg(test)] mod test { use ark_std::UniformRand;