splitting cfg from builder construction (#2)

This commit is contained in:
Dmitriy Ryajov 2024-01-26 10:40:39 -06:00 committed by GitHub
parent d6299d56ed
commit 60f6e5e059
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 69 deletions

View File

@ -2,7 +2,6 @@ use std::{
any::Any,
ffi::{c_char, CStr},
fs::File,
os::raw::c_void,
panic::{catch_unwind, AssertUnwindSafe},
ptr::slice_from_raw_parts_mut,
};
@ -19,16 +18,22 @@ use crate::ffi_types::*;
type GrothBn = Groth16<Bn254>;
#[derive(Debug, Clone)]
// #[repr(C)]
struct CircomBn254Cfg {
cfg: *mut CircomConfig<Bn254>,
proving_key: *mut ProvingKey<Bn254>,
_marker: core::marker::PhantomData<(*mut CircomBn254Cfg, core::marker::PhantomPinned)>,
}
#[derive(Debug, Clone)]
struct CircomBn254 {
builder: *mut CircomBuilder<Bn254>,
proving_key: *mut ProvingKey<Bn254>,
_marker: core::marker::PhantomData<(*mut CircomBn254, core::marker::PhantomPinned)>,
}
#[derive(Debug, Clone)]
struct CircomCompatCtx {
circom: *mut c_void,
circom: *mut CircomBn254,
cfg: *mut CircomBn254Cfg,
rng: ThreadRng,
_marker: core::marker::PhantomData<(*mut CircomCompatCtx, core::marker::PhantomPinned)>,
}
@ -43,32 +48,27 @@ fn to_err_code(result: Result<(), Box<dyn Any + Send>>) -> i32 {
}
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn init_circom_compat(
pub unsafe extern "C" fn init_circom_config(
r1cs_path: *const c_char,
wasm_path: *const c_char,
zkey_path: *const c_char,
ctx_ptr: &mut *mut CircomCompatCtx,
cfg_ptr: &mut *mut CircomBn254Cfg,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let mut rng = thread_rng(); // TODO: use a shared rng - how?
let builder = CircomBuilder::new(
CircomConfig::<Bn254>::new(
CStr::from_ptr(wasm_path)
.to_str()
.map_err(|_| ERR_WASM_PATH)
.unwrap(),
CStr::from_ptr(r1cs_path)
.to_str()
.map_err(|_| ERR_R1CS_PATH)
.unwrap(),
)
.map_err(|_| ERR_CIRCOM_BUILDER)
.unwrap(),
);
let cfg = CircomConfig::<Bn254>::new(
CStr::from_ptr(wasm_path)
.to_str()
.map_err(|_| ERR_WASM_PATH)
.unwrap(),
CStr::from_ptr(r1cs_path)
.to_str()
.map_err(|_| ERR_R1CS_PATH)
.unwrap(),
)
.map_err(|_| ERR_CIRCOM_BUILDER)
.unwrap();
let proving_key = if !zkey_path.is_null() {
let mut file = File::open(
@ -84,6 +84,8 @@ pub unsafe extern "C" fn init_circom_compat(
.unwrap()
.0
} else {
let mut rng = thread_rng();
let builder = CircomBuilder::new(cfg.clone());
Groth16::<Bn254>::generate_random_parameters_with_reduction::<_>(
builder.setup(),
&mut rng,
@ -92,14 +94,35 @@ pub unsafe extern "C" fn init_circom_compat(
.unwrap()
};
let circom_bn254_cfg = CircomBn254Cfg {
cfg: Box::into_raw(Box::new(cfg)),
proving_key: Box::into_raw(Box::new(proving_key)),
_marker: std::marker::PhantomData,
};
*cfg_ptr = Box::into_raw(Box::new(circom_bn254_cfg));
}));
to_err_code(result)
}
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn init_circom_compat(
cfg_ptr: *mut CircomBn254Cfg,
ctx_ptr: &mut *mut CircomCompatCtx,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let rng = thread_rng(); // TODO: use a shared rng - how?
let builder = CircomBuilder::new((*(*cfg_ptr).cfg).clone());
let circom_bn254 = CircomBn254 {
builder: Box::into_raw(Box::new(builder)),
proving_key: Box::into_raw(Box::new(proving_key)),
_marker: core::marker::PhantomData,
};
let circom_compat_ctx = CircomCompatCtx {
circom: Box::into_raw(Box::new(circom_bn254)) as *mut c_void,
circom: Box::into_raw(Box::new(circom_bn254)),
cfg: cfg_ptr,
rng: rng,
_marker: core::marker::PhantomData,
};
@ -116,32 +139,18 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt
if !ctx_ptr.is_null() {
let ctx = &mut Box::from_raw(*ctx_ptr);
if !ctx.circom.is_null() {
let circom = &mut Box::from_raw(ctx.circom as *mut CircomBn254);
let _ = Box::from_raw(circom.builder);
let _ = Box::from_raw(circom.proving_key);
if !circom.builder.is_null() {
circom.builder = std::ptr::null_mut()
};
if !circom.proving_key.is_null() {
circom.proving_key = std::ptr::null_mut()
};
ctx.circom = std::ptr::null_mut();
let circom = &mut Box::from_raw(ctx.circom);
let builder = Box::from_raw(circom.builder);
drop(builder);
let proving_key = Box::from_raw((*ctx.cfg).proving_key);
drop(proving_key);
let cfg = Box::from_raw((*ctx.cfg).cfg);
drop(cfg);
*ctx_ptr = std::ptr::null_mut();
}
}
}
#[no_mangle]
// Only use if the buffer was allocated by the ffi
pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) {
if !buff_ptr.is_null() {
let buff = Box::from_raw(*buff_ptr);
let data = Box::from_raw(slice_from_raw_parts_mut(buff.data as *mut u8, buff.len));
drop(data);
drop(buff);
}
}
#[no_mangle]
pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) {
if !proof_ptr.is_null() {
@ -181,8 +190,6 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
(*ctx_ptr).circom as *mut CircomBn254
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn prove_circuit(
@ -190,8 +197,9 @@ pub unsafe extern "C" fn prove_circuit(
proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let ctx: &mut CircomCompatCtx = &mut *ctx_ptr;
let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key);
let proving_key = (*ctx.cfg).proving_key;
let rng = &mut (*ctx_ptr).rng;
let circuit = (*circom.builder)
@ -200,7 +208,7 @@ pub unsafe extern "C" fn prove_circuit(
.map_err(|_| ERR_CIRCOM_BUILDER)
.unwrap();
let circom_proof = GrothBn::prove(proving_key, circuit, rng)
let circom_proof = GrothBn::prove(&*proving_key, circuit, rng)
.map_err(|_| ERR_MAKING_PROOF)
.unwrap();
@ -210,8 +218,6 @@ pub unsafe extern "C" fn prove_circuit(
to_err_code(result)
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn get_pub_inputs(
@ -236,8 +242,6 @@ pub unsafe extern "C" fn get_pub_inputs(
to_err_code(result)
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn get_verifying_key(
@ -245,8 +249,8 @@ pub unsafe extern "C" fn get_verifying_key(
vk_ptr: &mut *mut VerifyingKey, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key);
let ctx = &mut *ctx_ptr;
let proving_key = &(*(*ctx.cfg).proving_key);
let vk = prepare_verifying_key(&proving_key.vk).vk;
*vk_ptr = Box::leak(Box::new((&vk).into()));
@ -255,8 +259,6 @@ pub unsafe extern "C" fn get_verifying_key(
to_err_code(result)
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn verify_circuit(
@ -275,8 +277,6 @@ pub unsafe extern "C" fn verify_circuit(
to_err_code(result)
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn push_input_u256_array(
@ -350,11 +350,17 @@ mod test {
let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap();
unsafe {
let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut();
init_circom_compat(
let mut cfg_ptr: *mut CircomBn254Cfg = std::ptr::null_mut();
init_circom_config(
r1cs_path.as_ptr(),
wasm_path.as_ptr(),
std::ptr::null(),
&mut cfg_ptr,
);
let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut();
init_circom_compat(
cfg_ptr,
&mut ctx_ptr,
);

View File

@ -23,13 +23,6 @@ pub const ERR_MAKING_PROOF: i32 = 12;
pub const ERR_SERIALIZE_PROOF: i32 = 13;
pub const ERR_SERIALIZE_INPUTS: i32 = 14;
#[derive(Debug, Clone)]
#[repr(C)]
pub struct Buffer {
pub data: *const u8,
pub len: usize,
}
// Helper for converting a PrimeField to little endian byte slice
fn slice_to_point<F: PrimeField>(point: &[u8; 32]) -> F {
let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works");