From 2a08a6320b67b776dd891cc736b65f003bd896da Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Fri, 26 Jan 2024 10:33:16 -0600 Subject: [PATCH] splitting cfg from builder construction --- src/ffi.rs | 130 +++++++++++++++++++++++++---------------------- src/ffi_types.rs | 7 --- 2 files changed, 68 insertions(+), 69 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index 573abc4..bc3b07b 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -2,7 +2,6 @@ use std::{ any::Any, ffi::{c_char, CStr}, fs::File, - os::raw::c_void, panic::{catch_unwind, AssertUnwindSafe}, ptr::slice_from_raw_parts_mut, }; @@ -19,16 +18,22 @@ use crate::ffi_types::*; type GrothBn = Groth16; #[derive(Debug, Clone)] -// #[repr(C)] +struct CircomBn254Cfg { + cfg: *mut CircomConfig, + proving_key: *mut ProvingKey, + _marker: core::marker::PhantomData<(*mut CircomBn254Cfg, core::marker::PhantomPinned)>, +} + +#[derive(Debug, Clone)] struct CircomBn254 { builder: *mut CircomBuilder, - proving_key: *mut ProvingKey, _marker: core::marker::PhantomData<(*mut CircomBn254, core::marker::PhantomPinned)>, } #[derive(Debug, Clone)] struct CircomCompatCtx { - circom: *mut c_void, + circom: *mut CircomBn254, + cfg: *mut CircomBn254Cfg, rng: ThreadRng, _marker: core::marker::PhantomData<(*mut CircomCompatCtx, core::marker::PhantomPinned)>, } @@ -43,32 +48,27 @@ fn to_err_code(result: Result<(), Box>) -> i32 { } } -/// # Safety -/// #[no_mangle] #[allow(private_interfaces)] -pub unsafe extern "C" fn init_circom_compat( +pub unsafe extern "C" fn init_circom_config( r1cs_path: *const c_char, wasm_path: *const c_char, zkey_path: *const c_char, - ctx_ptr: &mut *mut CircomCompatCtx, + cfg_ptr: &mut *mut CircomBn254Cfg, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { - let mut rng = thread_rng(); // TODO: use a shared rng - how? - let builder = CircomBuilder::new( - CircomConfig::::new( - CStr::from_ptr(wasm_path) - .to_str() - .map_err(|_| ERR_WASM_PATH) - .unwrap(), - CStr::from_ptr(r1cs_path) - .to_str() - .map_err(|_| ERR_R1CS_PATH) - .unwrap(), - ) - .map_err(|_| ERR_CIRCOM_BUILDER) - .unwrap(), - ); + let cfg = CircomConfig::::new( + CStr::from_ptr(wasm_path) + .to_str() + .map_err(|_| ERR_WASM_PATH) + .unwrap(), + CStr::from_ptr(r1cs_path) + .to_str() + .map_err(|_| ERR_R1CS_PATH) + .unwrap(), + ) + .map_err(|_| ERR_CIRCOM_BUILDER) + .unwrap(); let proving_key = if !zkey_path.is_null() { let mut file = File::open( @@ -84,6 +84,8 @@ pub unsafe extern "C" fn init_circom_compat( .unwrap() .0 } else { + let mut rng = thread_rng(); + let builder = CircomBuilder::new(cfg.clone()); Groth16::::generate_random_parameters_with_reduction::<_>( builder.setup(), &mut rng, @@ -92,14 +94,35 @@ pub unsafe extern "C" fn init_circom_compat( .unwrap() }; + let circom_bn254_cfg = CircomBn254Cfg { + cfg: Box::into_raw(Box::new(cfg)), + proving_key: Box::into_raw(Box::new(proving_key)), + _marker: std::marker::PhantomData, + }; + + *cfg_ptr = Box::into_raw(Box::new(circom_bn254_cfg)); + })); + + to_err_code(result) +} + +#[no_mangle] +#[allow(private_interfaces)] +pub unsafe extern "C" fn init_circom_compat( + cfg_ptr: *mut CircomBn254Cfg, + ctx_ptr: &mut *mut CircomCompatCtx, +) -> i32 { + let result = catch_unwind(AssertUnwindSafe(|| { + let rng = thread_rng(); // TODO: use a shared rng - how? + let builder = CircomBuilder::new((*(*cfg_ptr).cfg).clone()); let circom_bn254 = CircomBn254 { builder: Box::into_raw(Box::new(builder)), - proving_key: Box::into_raw(Box::new(proving_key)), _marker: core::marker::PhantomData, }; let circom_compat_ctx = CircomCompatCtx { - circom: Box::into_raw(Box::new(circom_bn254)) as *mut c_void, + circom: Box::into_raw(Box::new(circom_bn254)), + cfg: cfg_ptr, rng: rng, _marker: core::marker::PhantomData, }; @@ -116,32 +139,18 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt if !ctx_ptr.is_null() { let ctx = &mut Box::from_raw(*ctx_ptr); if !ctx.circom.is_null() { - let circom = &mut Box::from_raw(ctx.circom as *mut CircomBn254); - let _ = Box::from_raw(circom.builder); - let _ = Box::from_raw(circom.proving_key); - if !circom.builder.is_null() { - circom.builder = std::ptr::null_mut() - }; - if !circom.proving_key.is_null() { - circom.proving_key = std::ptr::null_mut() - }; - ctx.circom = std::ptr::null_mut(); + let circom = &mut Box::from_raw(ctx.circom); + let builder = Box::from_raw(circom.builder); + drop(builder); + let proving_key = Box::from_raw((*ctx.cfg).proving_key); + drop(proving_key); + let cfg = Box::from_raw((*ctx.cfg).cfg); + drop(cfg); *ctx_ptr = std::ptr::null_mut(); } } } -#[no_mangle] -// Only use if the buffer was allocated by the ffi -pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) { - if !buff_ptr.is_null() { - let buff = Box::from_raw(*buff_ptr); - let data = Box::from_raw(slice_from_raw_parts_mut(buff.data as *mut u8, buff.len)); - drop(data); - drop(buff); - } -} - #[no_mangle] pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) { if !proof_ptr.is_null() { @@ -181,8 +190,6 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 { (*ctx_ptr).circom as *mut CircomBn254 } -/// # Safety -/// #[no_mangle] #[allow(private_interfaces)] pub unsafe extern "C" fn prove_circuit( @@ -190,8 +197,9 @@ pub unsafe extern "C" fn prove_circuit( proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { + let ctx: &mut CircomCompatCtx = &mut *ctx_ptr; let circom = &mut *to_circom(ctx_ptr); - let proving_key = &(*circom.proving_key); + let proving_key = (*ctx.cfg).proving_key; let rng = &mut (*ctx_ptr).rng; let circuit = (*circom.builder) @@ -200,7 +208,7 @@ pub unsafe extern "C" fn prove_circuit( .map_err(|_| ERR_CIRCOM_BUILDER) .unwrap(); - let circom_proof = GrothBn::prove(proving_key, circuit, rng) + let circom_proof = GrothBn::prove(&*proving_key, circuit, rng) .map_err(|_| ERR_MAKING_PROOF) .unwrap(); @@ -210,8 +218,6 @@ pub unsafe extern "C" fn prove_circuit( to_err_code(result) } -/// # Safety -/// #[no_mangle] #[allow(private_interfaces)] pub unsafe extern "C" fn get_pub_inputs( @@ -236,8 +242,6 @@ pub unsafe extern "C" fn get_pub_inputs( to_err_code(result) } -/// # Safety -/// #[no_mangle] #[allow(private_interfaces)] pub unsafe extern "C" fn get_verifying_key( @@ -245,8 +249,8 @@ pub unsafe extern "C" fn get_verifying_key( vk_ptr: &mut *mut VerifyingKey, // inputs_bytes_ptr: &mut *mut Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { - let circom = &mut *to_circom(ctx_ptr); - let proving_key = &(*circom.proving_key); + let ctx = &mut *ctx_ptr; + let proving_key = &(*(*ctx.cfg).proving_key); let vk = prepare_verifying_key(&proving_key.vk).vk; *vk_ptr = Box::leak(Box::new((&vk).into())); @@ -255,8 +259,6 @@ pub unsafe extern "C" fn get_verifying_key( to_err_code(result) } -/// # Safety -/// #[no_mangle] #[allow(private_interfaces)] pub unsafe extern "C" fn verify_circuit( @@ -275,8 +277,6 @@ pub unsafe extern "C" fn verify_circuit( to_err_code(result) } -/// # Safety -/// #[no_mangle] #[allow(private_interfaces)] pub unsafe extern "C" fn push_input_u256_array( @@ -350,11 +350,17 @@ mod test { let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap(); unsafe { - let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut(); - init_circom_compat( + let mut cfg_ptr: *mut CircomBn254Cfg = std::ptr::null_mut(); + init_circom_config( r1cs_path.as_ptr(), wasm_path.as_ptr(), std::ptr::null(), + &mut cfg_ptr, + ); + + let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut(); + init_circom_compat( + cfg_ptr, &mut ctx_ptr, ); diff --git a/src/ffi_types.rs b/src/ffi_types.rs index 7456e9a..d47127d 100644 --- a/src/ffi_types.rs +++ b/src/ffi_types.rs @@ -23,13 +23,6 @@ pub const ERR_MAKING_PROOF: i32 = 12; pub const ERR_SERIALIZE_PROOF: i32 = 13; pub const ERR_SERIALIZE_INPUTS: i32 = 14; -#[derive(Debug, Clone)] -#[repr(C)] -pub struct Buffer { - pub data: *const u8, - pub len: usize, -} - // Helper for converting a PrimeField to little endian byte slice fn slice_to_point(point: &[u8; 32]) -> F { let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works");