splitting cfg from builder construction
This commit is contained in:
parent
d6299d56ed
commit
2a08a6320b
110
src/ffi.rs
110
src/ffi.rs
|
@ -2,7 +2,6 @@ use std::{
|
||||||
any::Any,
|
any::Any,
|
||||||
ffi::{c_char, CStr},
|
ffi::{c_char, CStr},
|
||||||
fs::File,
|
fs::File,
|
||||||
os::raw::c_void,
|
|
||||||
panic::{catch_unwind, AssertUnwindSafe},
|
panic::{catch_unwind, AssertUnwindSafe},
|
||||||
ptr::slice_from_raw_parts_mut,
|
ptr::slice_from_raw_parts_mut,
|
||||||
};
|
};
|
||||||
|
@ -19,16 +18,22 @@ use crate::ffi_types::*;
|
||||||
type GrothBn = Groth16<Bn254>;
|
type GrothBn = Groth16<Bn254>;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
// #[repr(C)]
|
struct CircomBn254Cfg {
|
||||||
|
cfg: *mut CircomConfig<Bn254>,
|
||||||
|
proving_key: *mut ProvingKey<Bn254>,
|
||||||
|
_marker: core::marker::PhantomData<(*mut CircomBn254Cfg, core::marker::PhantomPinned)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct CircomBn254 {
|
struct CircomBn254 {
|
||||||
builder: *mut CircomBuilder<Bn254>,
|
builder: *mut CircomBuilder<Bn254>,
|
||||||
proving_key: *mut ProvingKey<Bn254>,
|
|
||||||
_marker: core::marker::PhantomData<(*mut CircomBn254, core::marker::PhantomPinned)>,
|
_marker: core::marker::PhantomData<(*mut CircomBn254, core::marker::PhantomPinned)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct CircomCompatCtx {
|
struct CircomCompatCtx {
|
||||||
circom: *mut c_void,
|
circom: *mut CircomBn254,
|
||||||
|
cfg: *mut CircomBn254Cfg,
|
||||||
rng: ThreadRng,
|
rng: ThreadRng,
|
||||||
_marker: core::marker::PhantomData<(*mut CircomCompatCtx, core::marker::PhantomPinned)>,
|
_marker: core::marker::PhantomData<(*mut CircomCompatCtx, core::marker::PhantomPinned)>,
|
||||||
}
|
}
|
||||||
|
@ -43,20 +48,16 @@ fn to_err_code(result: Result<(), Box<dyn Any + Send>>) -> i32 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
#[allow(private_interfaces)]
|
#[allow(private_interfaces)]
|
||||||
pub unsafe extern "C" fn init_circom_compat(
|
pub unsafe extern "C" fn init_circom_config(
|
||||||
r1cs_path: *const c_char,
|
r1cs_path: *const c_char,
|
||||||
wasm_path: *const c_char,
|
wasm_path: *const c_char,
|
||||||
zkey_path: *const c_char,
|
zkey_path: *const c_char,
|
||||||
ctx_ptr: &mut *mut CircomCompatCtx,
|
cfg_ptr: &mut *mut CircomBn254Cfg,
|
||||||
) -> i32 {
|
) -> i32 {
|
||||||
let result = catch_unwind(AssertUnwindSafe(|| {
|
let result = catch_unwind(AssertUnwindSafe(|| {
|
||||||
let mut rng = thread_rng(); // TODO: use a shared rng - how?
|
let cfg = CircomConfig::<Bn254>::new(
|
||||||
let builder = CircomBuilder::new(
|
|
||||||
CircomConfig::<Bn254>::new(
|
|
||||||
CStr::from_ptr(wasm_path)
|
CStr::from_ptr(wasm_path)
|
||||||
.to_str()
|
.to_str()
|
||||||
.map_err(|_| ERR_WASM_PATH)
|
.map_err(|_| ERR_WASM_PATH)
|
||||||
|
@ -67,8 +68,7 @@ pub unsafe extern "C" fn init_circom_compat(
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
)
|
)
|
||||||
.map_err(|_| ERR_CIRCOM_BUILDER)
|
.map_err(|_| ERR_CIRCOM_BUILDER)
|
||||||
.unwrap(),
|
.unwrap();
|
||||||
);
|
|
||||||
|
|
||||||
let proving_key = if !zkey_path.is_null() {
|
let proving_key = if !zkey_path.is_null() {
|
||||||
let mut file = File::open(
|
let mut file = File::open(
|
||||||
|
@ -84,6 +84,8 @@ pub unsafe extern "C" fn init_circom_compat(
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.0
|
.0
|
||||||
} else {
|
} else {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let builder = CircomBuilder::new(cfg.clone());
|
||||||
Groth16::<Bn254>::generate_random_parameters_with_reduction::<_>(
|
Groth16::<Bn254>::generate_random_parameters_with_reduction::<_>(
|
||||||
builder.setup(),
|
builder.setup(),
|
||||||
&mut rng,
|
&mut rng,
|
||||||
|
@ -92,14 +94,35 @@ pub unsafe extern "C" fn init_circom_compat(
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let circom_bn254_cfg = CircomBn254Cfg {
|
||||||
|
cfg: Box::into_raw(Box::new(cfg)),
|
||||||
|
proving_key: Box::into_raw(Box::new(proving_key)),
|
||||||
|
_marker: std::marker::PhantomData,
|
||||||
|
};
|
||||||
|
|
||||||
|
*cfg_ptr = Box::into_raw(Box::new(circom_bn254_cfg));
|
||||||
|
}));
|
||||||
|
|
||||||
|
to_err_code(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
#[allow(private_interfaces)]
|
||||||
|
pub unsafe extern "C" fn init_circom_compat(
|
||||||
|
cfg_ptr: *mut CircomBn254Cfg,
|
||||||
|
ctx_ptr: &mut *mut CircomCompatCtx,
|
||||||
|
) -> i32 {
|
||||||
|
let result = catch_unwind(AssertUnwindSafe(|| {
|
||||||
|
let rng = thread_rng(); // TODO: use a shared rng - how?
|
||||||
|
let builder = CircomBuilder::new((*(*cfg_ptr).cfg).clone());
|
||||||
let circom_bn254 = CircomBn254 {
|
let circom_bn254 = CircomBn254 {
|
||||||
builder: Box::into_raw(Box::new(builder)),
|
builder: Box::into_raw(Box::new(builder)),
|
||||||
proving_key: Box::into_raw(Box::new(proving_key)),
|
|
||||||
_marker: core::marker::PhantomData,
|
_marker: core::marker::PhantomData,
|
||||||
};
|
};
|
||||||
|
|
||||||
let circom_compat_ctx = CircomCompatCtx {
|
let circom_compat_ctx = CircomCompatCtx {
|
||||||
circom: Box::into_raw(Box::new(circom_bn254)) as *mut c_void,
|
circom: Box::into_raw(Box::new(circom_bn254)),
|
||||||
|
cfg: cfg_ptr,
|
||||||
rng: rng,
|
rng: rng,
|
||||||
_marker: core::marker::PhantomData,
|
_marker: core::marker::PhantomData,
|
||||||
};
|
};
|
||||||
|
@ -116,32 +139,18 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt
|
||||||
if !ctx_ptr.is_null() {
|
if !ctx_ptr.is_null() {
|
||||||
let ctx = &mut Box::from_raw(*ctx_ptr);
|
let ctx = &mut Box::from_raw(*ctx_ptr);
|
||||||
if !ctx.circom.is_null() {
|
if !ctx.circom.is_null() {
|
||||||
let circom = &mut Box::from_raw(ctx.circom as *mut CircomBn254);
|
let circom = &mut Box::from_raw(ctx.circom);
|
||||||
let _ = Box::from_raw(circom.builder);
|
let builder = Box::from_raw(circom.builder);
|
||||||
let _ = Box::from_raw(circom.proving_key);
|
drop(builder);
|
||||||
if !circom.builder.is_null() {
|
let proving_key = Box::from_raw((*ctx.cfg).proving_key);
|
||||||
circom.builder = std::ptr::null_mut()
|
drop(proving_key);
|
||||||
};
|
let cfg = Box::from_raw((*ctx.cfg).cfg);
|
||||||
if !circom.proving_key.is_null() {
|
drop(cfg);
|
||||||
circom.proving_key = std::ptr::null_mut()
|
|
||||||
};
|
|
||||||
ctx.circom = std::ptr::null_mut();
|
|
||||||
*ctx_ptr = std::ptr::null_mut();
|
*ctx_ptr = std::ptr::null_mut();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[no_mangle]
|
|
||||||
// Only use if the buffer was allocated by the ffi
|
|
||||||
pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) {
|
|
||||||
if !buff_ptr.is_null() {
|
|
||||||
let buff = Box::from_raw(*buff_ptr);
|
|
||||||
let data = Box::from_raw(slice_from_raw_parts_mut(buff.data as *mut u8, buff.len));
|
|
||||||
drop(data);
|
|
||||||
drop(buff);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) {
|
pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) {
|
||||||
if !proof_ptr.is_null() {
|
if !proof_ptr.is_null() {
|
||||||
|
@ -181,8 +190,6 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
|
||||||
(*ctx_ptr).circom as *mut CircomBn254
|
(*ctx_ptr).circom as *mut CircomBn254
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
#[allow(private_interfaces)]
|
#[allow(private_interfaces)]
|
||||||
pub unsafe extern "C" fn prove_circuit(
|
pub unsafe extern "C" fn prove_circuit(
|
||||||
|
@ -190,8 +197,9 @@ pub unsafe extern "C" fn prove_circuit(
|
||||||
proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer,
|
proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer,
|
||||||
) -> i32 {
|
) -> i32 {
|
||||||
let result = catch_unwind(AssertUnwindSafe(|| {
|
let result = catch_unwind(AssertUnwindSafe(|| {
|
||||||
|
let ctx: &mut CircomCompatCtx = &mut *ctx_ptr;
|
||||||
let circom = &mut *to_circom(ctx_ptr);
|
let circom = &mut *to_circom(ctx_ptr);
|
||||||
let proving_key = &(*circom.proving_key);
|
let proving_key = (*ctx.cfg).proving_key;
|
||||||
let rng = &mut (*ctx_ptr).rng;
|
let rng = &mut (*ctx_ptr).rng;
|
||||||
|
|
||||||
let circuit = (*circom.builder)
|
let circuit = (*circom.builder)
|
||||||
|
@ -200,7 +208,7 @@ pub unsafe extern "C" fn prove_circuit(
|
||||||
.map_err(|_| ERR_CIRCOM_BUILDER)
|
.map_err(|_| ERR_CIRCOM_BUILDER)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let circom_proof = GrothBn::prove(proving_key, circuit, rng)
|
let circom_proof = GrothBn::prove(&*proving_key, circuit, rng)
|
||||||
.map_err(|_| ERR_MAKING_PROOF)
|
.map_err(|_| ERR_MAKING_PROOF)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -210,8 +218,6 @@ pub unsafe extern "C" fn prove_circuit(
|
||||||
to_err_code(result)
|
to_err_code(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
#[allow(private_interfaces)]
|
#[allow(private_interfaces)]
|
||||||
pub unsafe extern "C" fn get_pub_inputs(
|
pub unsafe extern "C" fn get_pub_inputs(
|
||||||
|
@ -236,8 +242,6 @@ pub unsafe extern "C" fn get_pub_inputs(
|
||||||
to_err_code(result)
|
to_err_code(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
#[allow(private_interfaces)]
|
#[allow(private_interfaces)]
|
||||||
pub unsafe extern "C" fn get_verifying_key(
|
pub unsafe extern "C" fn get_verifying_key(
|
||||||
|
@ -245,8 +249,8 @@ pub unsafe extern "C" fn get_verifying_key(
|
||||||
vk_ptr: &mut *mut VerifyingKey, // inputs_bytes_ptr: &mut *mut Buffer,
|
vk_ptr: &mut *mut VerifyingKey, // inputs_bytes_ptr: &mut *mut Buffer,
|
||||||
) -> i32 {
|
) -> i32 {
|
||||||
let result = catch_unwind(AssertUnwindSafe(|| {
|
let result = catch_unwind(AssertUnwindSafe(|| {
|
||||||
let circom = &mut *to_circom(ctx_ptr);
|
let ctx = &mut *ctx_ptr;
|
||||||
let proving_key = &(*circom.proving_key);
|
let proving_key = &(*(*ctx.cfg).proving_key);
|
||||||
let vk = prepare_verifying_key(&proving_key.vk).vk;
|
let vk = prepare_verifying_key(&proving_key.vk).vk;
|
||||||
|
|
||||||
*vk_ptr = Box::leak(Box::new((&vk).into()));
|
*vk_ptr = Box::leak(Box::new((&vk).into()));
|
||||||
|
@ -255,8 +259,6 @@ pub unsafe extern "C" fn get_verifying_key(
|
||||||
to_err_code(result)
|
to_err_code(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
#[allow(private_interfaces)]
|
#[allow(private_interfaces)]
|
||||||
pub unsafe extern "C" fn verify_circuit(
|
pub unsafe extern "C" fn verify_circuit(
|
||||||
|
@ -275,8 +277,6 @@ pub unsafe extern "C" fn verify_circuit(
|
||||||
to_err_code(result)
|
to_err_code(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
#[allow(private_interfaces)]
|
#[allow(private_interfaces)]
|
||||||
pub unsafe extern "C" fn push_input_u256_array(
|
pub unsafe extern "C" fn push_input_u256_array(
|
||||||
|
@ -350,11 +350,17 @@ mod test {
|
||||||
let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap();
|
let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap();
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut();
|
let mut cfg_ptr: *mut CircomBn254Cfg = std::ptr::null_mut();
|
||||||
init_circom_compat(
|
init_circom_config(
|
||||||
r1cs_path.as_ptr(),
|
r1cs_path.as_ptr(),
|
||||||
wasm_path.as_ptr(),
|
wasm_path.as_ptr(),
|
||||||
std::ptr::null(),
|
std::ptr::null(),
|
||||||
|
&mut cfg_ptr,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut();
|
||||||
|
init_circom_compat(
|
||||||
|
cfg_ptr,
|
||||||
&mut ctx_ptr,
|
&mut ctx_ptr,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -23,13 +23,6 @@ pub const ERR_MAKING_PROOF: i32 = 12;
|
||||||
pub const ERR_SERIALIZE_PROOF: i32 = 13;
|
pub const ERR_SERIALIZE_PROOF: i32 = 13;
|
||||||
pub const ERR_SERIALIZE_INPUTS: i32 = 14;
|
pub const ERR_SERIALIZE_INPUTS: i32 = 14;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
#[repr(C)]
|
|
||||||
pub struct Buffer {
|
|
||||||
pub data: *const u8,
|
|
||||||
pub len: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper for converting a PrimeField to little endian byte slice
|
// Helper for converting a PrimeField to little endian byte slice
|
||||||
fn slice_to_point<F: PrimeField>(point: &[u8; 32]) -> F {
|
fn slice_to_point<F: PrimeField>(point: &[u8; 32]) -> F {
|
||||||
let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works");
|
let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works");
|
||||||
|
|
Loading…
Reference in New Issue