splitting cfg from builder construction

This commit is contained in:
Dmitriy Ryajov 2024-01-26 10:33:16 -06:00
parent d6299d56ed
commit 2a08a6320b
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
2 changed files with 68 additions and 69 deletions

View File

@ -2,7 +2,6 @@ use std::{
any::Any, any::Any,
ffi::{c_char, CStr}, ffi::{c_char, CStr},
fs::File, fs::File,
os::raw::c_void,
panic::{catch_unwind, AssertUnwindSafe}, panic::{catch_unwind, AssertUnwindSafe},
ptr::slice_from_raw_parts_mut, ptr::slice_from_raw_parts_mut,
}; };
@ -19,16 +18,22 @@ use crate::ffi_types::*;
type GrothBn = Groth16<Bn254>; type GrothBn = Groth16<Bn254>;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
// #[repr(C)] struct CircomBn254Cfg {
cfg: *mut CircomConfig<Bn254>,
proving_key: *mut ProvingKey<Bn254>,
_marker: core::marker::PhantomData<(*mut CircomBn254Cfg, core::marker::PhantomPinned)>,
}
#[derive(Debug, Clone)]
struct CircomBn254 { struct CircomBn254 {
builder: *mut CircomBuilder<Bn254>, builder: *mut CircomBuilder<Bn254>,
proving_key: *mut ProvingKey<Bn254>,
_marker: core::marker::PhantomData<(*mut CircomBn254, core::marker::PhantomPinned)>, _marker: core::marker::PhantomData<(*mut CircomBn254, core::marker::PhantomPinned)>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct CircomCompatCtx { struct CircomCompatCtx {
circom: *mut c_void, circom: *mut CircomBn254,
cfg: *mut CircomBn254Cfg,
rng: ThreadRng, rng: ThreadRng,
_marker: core::marker::PhantomData<(*mut CircomCompatCtx, core::marker::PhantomPinned)>, _marker: core::marker::PhantomData<(*mut CircomCompatCtx, core::marker::PhantomPinned)>,
} }
@ -43,20 +48,16 @@ fn to_err_code(result: Result<(), Box<dyn Any + Send>>) -> i32 {
} }
} }
/// # Safety
///
#[no_mangle] #[no_mangle]
#[allow(private_interfaces)] #[allow(private_interfaces)]
pub unsafe extern "C" fn init_circom_compat( pub unsafe extern "C" fn init_circom_config(
r1cs_path: *const c_char, r1cs_path: *const c_char,
wasm_path: *const c_char, wasm_path: *const c_char,
zkey_path: *const c_char, zkey_path: *const c_char,
ctx_ptr: &mut *mut CircomCompatCtx, cfg_ptr: &mut *mut CircomBn254Cfg,
) -> i32 { ) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| { let result = catch_unwind(AssertUnwindSafe(|| {
let mut rng = thread_rng(); // TODO: use a shared rng - how? let cfg = CircomConfig::<Bn254>::new(
let builder = CircomBuilder::new(
CircomConfig::<Bn254>::new(
CStr::from_ptr(wasm_path) CStr::from_ptr(wasm_path)
.to_str() .to_str()
.map_err(|_| ERR_WASM_PATH) .map_err(|_| ERR_WASM_PATH)
@ -67,8 +68,7 @@ pub unsafe extern "C" fn init_circom_compat(
.unwrap(), .unwrap(),
) )
.map_err(|_| ERR_CIRCOM_BUILDER) .map_err(|_| ERR_CIRCOM_BUILDER)
.unwrap(), .unwrap();
);
let proving_key = if !zkey_path.is_null() { let proving_key = if !zkey_path.is_null() {
let mut file = File::open( let mut file = File::open(
@ -84,6 +84,8 @@ pub unsafe extern "C" fn init_circom_compat(
.unwrap() .unwrap()
.0 .0
} else { } else {
let mut rng = thread_rng();
let builder = CircomBuilder::new(cfg.clone());
Groth16::<Bn254>::generate_random_parameters_with_reduction::<_>( Groth16::<Bn254>::generate_random_parameters_with_reduction::<_>(
builder.setup(), builder.setup(),
&mut rng, &mut rng,
@ -92,14 +94,35 @@ pub unsafe extern "C" fn init_circom_compat(
.unwrap() .unwrap()
}; };
let circom_bn254_cfg = CircomBn254Cfg {
cfg: Box::into_raw(Box::new(cfg)),
proving_key: Box::into_raw(Box::new(proving_key)),
_marker: std::marker::PhantomData,
};
*cfg_ptr = Box::into_raw(Box::new(circom_bn254_cfg));
}));
to_err_code(result)
}
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn init_circom_compat(
cfg_ptr: *mut CircomBn254Cfg,
ctx_ptr: &mut *mut CircomCompatCtx,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let rng = thread_rng(); // TODO: use a shared rng - how?
let builder = CircomBuilder::new((*(*cfg_ptr).cfg).clone());
let circom_bn254 = CircomBn254 { let circom_bn254 = CircomBn254 {
builder: Box::into_raw(Box::new(builder)), builder: Box::into_raw(Box::new(builder)),
proving_key: Box::into_raw(Box::new(proving_key)),
_marker: core::marker::PhantomData, _marker: core::marker::PhantomData,
}; };
let circom_compat_ctx = CircomCompatCtx { let circom_compat_ctx = CircomCompatCtx {
circom: Box::into_raw(Box::new(circom_bn254)) as *mut c_void, circom: Box::into_raw(Box::new(circom_bn254)),
cfg: cfg_ptr,
rng: rng, rng: rng,
_marker: core::marker::PhantomData, _marker: core::marker::PhantomData,
}; };
@ -116,32 +139,18 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt
if !ctx_ptr.is_null() { if !ctx_ptr.is_null() {
let ctx = &mut Box::from_raw(*ctx_ptr); let ctx = &mut Box::from_raw(*ctx_ptr);
if !ctx.circom.is_null() { if !ctx.circom.is_null() {
let circom = &mut Box::from_raw(ctx.circom as *mut CircomBn254); let circom = &mut Box::from_raw(ctx.circom);
let _ = Box::from_raw(circom.builder); let builder = Box::from_raw(circom.builder);
let _ = Box::from_raw(circom.proving_key); drop(builder);
if !circom.builder.is_null() { let proving_key = Box::from_raw((*ctx.cfg).proving_key);
circom.builder = std::ptr::null_mut() drop(proving_key);
}; let cfg = Box::from_raw((*ctx.cfg).cfg);
if !circom.proving_key.is_null() { drop(cfg);
circom.proving_key = std::ptr::null_mut()
};
ctx.circom = std::ptr::null_mut();
*ctx_ptr = std::ptr::null_mut(); *ctx_ptr = std::ptr::null_mut();
} }
} }
} }
#[no_mangle]
// Only use if the buffer was allocated by the ffi
pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) {
if !buff_ptr.is_null() {
let buff = Box::from_raw(*buff_ptr);
let data = Box::from_raw(slice_from_raw_parts_mut(buff.data as *mut u8, buff.len));
drop(data);
drop(buff);
}
}
#[no_mangle] #[no_mangle]
pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) { pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) {
if !proof_ptr.is_null() { if !proof_ptr.is_null() {
@ -181,8 +190,6 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
(*ctx_ptr).circom as *mut CircomBn254 (*ctx_ptr).circom as *mut CircomBn254
} }
/// # Safety
///
#[no_mangle] #[no_mangle]
#[allow(private_interfaces)] #[allow(private_interfaces)]
pub unsafe extern "C" fn prove_circuit( pub unsafe extern "C" fn prove_circuit(
@ -190,8 +197,9 @@ pub unsafe extern "C" fn prove_circuit(
proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer, proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 { ) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| { let result = catch_unwind(AssertUnwindSafe(|| {
let ctx: &mut CircomCompatCtx = &mut *ctx_ptr;
let circom = &mut *to_circom(ctx_ptr); let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key); let proving_key = (*ctx.cfg).proving_key;
let rng = &mut (*ctx_ptr).rng; let rng = &mut (*ctx_ptr).rng;
let circuit = (*circom.builder) let circuit = (*circom.builder)
@ -200,7 +208,7 @@ pub unsafe extern "C" fn prove_circuit(
.map_err(|_| ERR_CIRCOM_BUILDER) .map_err(|_| ERR_CIRCOM_BUILDER)
.unwrap(); .unwrap();
let circom_proof = GrothBn::prove(proving_key, circuit, rng) let circom_proof = GrothBn::prove(&*proving_key, circuit, rng)
.map_err(|_| ERR_MAKING_PROOF) .map_err(|_| ERR_MAKING_PROOF)
.unwrap(); .unwrap();
@ -210,8 +218,6 @@ pub unsafe extern "C" fn prove_circuit(
to_err_code(result) to_err_code(result)
} }
/// # Safety
///
#[no_mangle] #[no_mangle]
#[allow(private_interfaces)] #[allow(private_interfaces)]
pub unsafe extern "C" fn get_pub_inputs( pub unsafe extern "C" fn get_pub_inputs(
@ -236,8 +242,6 @@ pub unsafe extern "C" fn get_pub_inputs(
to_err_code(result) to_err_code(result)
} }
/// # Safety
///
#[no_mangle] #[no_mangle]
#[allow(private_interfaces)] #[allow(private_interfaces)]
pub unsafe extern "C" fn get_verifying_key( pub unsafe extern "C" fn get_verifying_key(
@ -245,8 +249,8 @@ pub unsafe extern "C" fn get_verifying_key(
vk_ptr: &mut *mut VerifyingKey, // inputs_bytes_ptr: &mut *mut Buffer, vk_ptr: &mut *mut VerifyingKey, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 { ) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| { let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr); let ctx = &mut *ctx_ptr;
let proving_key = &(*circom.proving_key); let proving_key = &(*(*ctx.cfg).proving_key);
let vk = prepare_verifying_key(&proving_key.vk).vk; let vk = prepare_verifying_key(&proving_key.vk).vk;
*vk_ptr = Box::leak(Box::new((&vk).into())); *vk_ptr = Box::leak(Box::new((&vk).into()));
@ -255,8 +259,6 @@ pub unsafe extern "C" fn get_verifying_key(
to_err_code(result) to_err_code(result)
} }
/// # Safety
///
#[no_mangle] #[no_mangle]
#[allow(private_interfaces)] #[allow(private_interfaces)]
pub unsafe extern "C" fn verify_circuit( pub unsafe extern "C" fn verify_circuit(
@ -275,8 +277,6 @@ pub unsafe extern "C" fn verify_circuit(
to_err_code(result) to_err_code(result)
} }
/// # Safety
///
#[no_mangle] #[no_mangle]
#[allow(private_interfaces)] #[allow(private_interfaces)]
pub unsafe extern "C" fn push_input_u256_array( pub unsafe extern "C" fn push_input_u256_array(
@ -350,11 +350,17 @@ mod test {
let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap(); let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap();
unsafe { unsafe {
let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut(); let mut cfg_ptr: *mut CircomBn254Cfg = std::ptr::null_mut();
init_circom_compat( init_circom_config(
r1cs_path.as_ptr(), r1cs_path.as_ptr(),
wasm_path.as_ptr(), wasm_path.as_ptr(),
std::ptr::null(), std::ptr::null(),
&mut cfg_ptr,
);
let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut();
init_circom_compat(
cfg_ptr,
&mut ctx_ptr, &mut ctx_ptr,
); );

View File

@ -23,13 +23,6 @@ pub const ERR_MAKING_PROOF: i32 = 12;
pub const ERR_SERIALIZE_PROOF: i32 = 13; pub const ERR_SERIALIZE_PROOF: i32 = 13;
pub const ERR_SERIALIZE_INPUTS: i32 = 14; pub const ERR_SERIALIZE_INPUTS: i32 = 14;
#[derive(Debug, Clone)]
#[repr(C)]
pub struct Buffer {
pub data: *const u8,
pub len: usize,
}
// Helper for converting a PrimeField to little endian byte slice // Helper for converting a PrimeField to little endian byte slice
fn slice_to_point<F: PrimeField>(point: &[u8; 32]) -> F { fn slice_to_point<F: PrimeField>(point: &[u8; 32]) -> F {
let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works"); let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works");