minor cleanup and err handling
This commit is contained in:
parent
94cfed822c
commit
87fa8cdb96
95
src/ffi.rs
95
src/ffi.rs
|
@ -3,7 +3,7 @@ use std::{
|
|||
ffi::{c_char, CStr},
|
||||
fs::File,
|
||||
os::raw::c_void,
|
||||
panic::{catch_unwind, AssertUnwindSafe},
|
||||
panic::{catch_unwind, AssertUnwindSafe}, io::Cursor,
|
||||
};
|
||||
|
||||
use ark_bn254::{Bn254, Fr};
|
||||
|
@ -28,6 +28,10 @@ pub const ERR_CANT_READ_ZKEY: i32 = 6;
|
|||
pub const ERR_CIRCOM_BUILDER: i32 = 7;
|
||||
pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8;
|
||||
pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9;
|
||||
pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10;
|
||||
pub const ERR_GET_PUB_INPUTS: i32 = 11;
|
||||
pub const ERR_MAKING_PROOF: i32 = 12;
|
||||
pub const ERR_SERIALIZE_PROOF: i32 = 13;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(C)]
|
||||
|
@ -150,6 +154,7 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt
|
|||
}
|
||||
|
||||
#[no_mangle]
|
||||
// Only use if the buffer was allocated by the ffi
|
||||
pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) {
|
||||
if !buff_ptr.is_null() {
|
||||
let buff = &mut Box::from_raw(*buff_ptr);
|
||||
|
@ -164,39 +169,6 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
|
|||
(*ctx_ptr).circom as *mut CircomBn254
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
#[no_mangle]
|
||||
#[allow(private_interfaces)]
|
||||
pub unsafe extern "C" fn push_input_u256_array(
|
||||
ctx_ptr: *mut CircomCompatCtx,
|
||||
name_ptr: *const c_char,
|
||||
input_ptr: *const u8,
|
||||
len: usize,
|
||||
) -> i32 {
|
||||
let result = catch_unwind(AssertUnwindSafe(|| {
|
||||
let name = CStr::from_ptr(name_ptr)
|
||||
.to_str()
|
||||
.map_err(|_| ERR_INPUT_NAME)
|
||||
.unwrap();
|
||||
|
||||
let input = {
|
||||
let slice = std::slice::from_raw_parts(input_ptr, len);
|
||||
slice
|
||||
.chunks(U256::BYTES)
|
||||
.map(|c| U256::try_from_le_slice(c).ok_or(ERR_INVALID_INPUT).unwrap())
|
||||
.collect::<Vec<U256>>()
|
||||
};
|
||||
|
||||
let circom = &mut *to_circom(ctx_ptr);
|
||||
input
|
||||
.iter()
|
||||
.for_each(|c| (*circom.builder).push_input(name, *c));
|
||||
}));
|
||||
|
||||
to_err_code(result)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
#[no_mangle]
|
||||
|
@ -212,10 +184,19 @@ pub unsafe extern "C" fn prove_circuit(
|
|||
let proving_key = &(*circom.proving_key);
|
||||
let rng = &mut (*ctx_ptr).rng;
|
||||
|
||||
let circuit = (*circom.builder).clone().build().unwrap();
|
||||
let circuit = (*circom.builder)
|
||||
.clone()
|
||||
.build()
|
||||
.map_err(|_| ERR_CIRCOM_BUILDER)
|
||||
.unwrap();
|
||||
|
||||
let inputs = circuit.get_public_inputs().unwrap();
|
||||
let proof = GrothBn::prove(&proving_key, circuit, rng).unwrap();
|
||||
let inputs = circuit
|
||||
.get_public_inputs()
|
||||
.ok_or_else(|| ERR_GET_PUB_INPUTS)
|
||||
.unwrap();
|
||||
let proof = GrothBn::prove(&proving_key, circuit, rng)
|
||||
.map_err(|_| ERR_MAKING_PROOF)
|
||||
.unwrap();
|
||||
|
||||
let mut proof_bytes = Vec::new();
|
||||
proof.serialize_compressed(&mut proof_bytes).unwrap();
|
||||
|
@ -223,6 +204,7 @@ pub unsafe extern "C" fn prove_circuit(
|
|||
let mut public_inputs_bytes = Vec::new();
|
||||
inputs
|
||||
.serialize_compressed(&mut public_inputs_bytes)
|
||||
.map_err(|_| ERR_SERIALIZE_PROOF)
|
||||
.unwrap();
|
||||
|
||||
// leak the buffers to avoid rust from freeing the pointed to data,
|
||||
|
@ -278,13 +260,46 @@ pub unsafe extern "C" fn verify_circuit(
|
|||
let pvk = prepare_verifying_key(&proving_key.vk);
|
||||
|
||||
GrothBn::verify_proof(&pvk, &proof, &public_inputs)
|
||||
.map_err(|e| e.to_string())
|
||||
.map_err(|_| ERR_FAILED_TO_VERIFY_PROOF)
|
||||
.unwrap();
|
||||
}));
|
||||
|
||||
to_err_code(result)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
#[no_mangle]
|
||||
#[allow(private_interfaces)]
|
||||
pub unsafe extern "C" fn push_input_u256_array(
|
||||
ctx_ptr: *mut CircomCompatCtx,
|
||||
name_ptr: *const c_char,
|
||||
input_ptr: *const u8,
|
||||
len: usize,
|
||||
) -> i32 {
|
||||
let result = catch_unwind(AssertUnwindSafe(|| {
|
||||
let name = CStr::from_ptr(name_ptr)
|
||||
.to_str()
|
||||
.map_err(|_| ERR_INPUT_NAME)
|
||||
.unwrap();
|
||||
|
||||
let input = {
|
||||
let slice = std::slice::from_raw_parts(input_ptr, len);
|
||||
slice
|
||||
.chunks(U256::BYTES)
|
||||
.map(|c| U256::try_from_le_slice(c).ok_or(ERR_INVALID_INPUT).unwrap())
|
||||
.collect::<Vec<U256>>()
|
||||
};
|
||||
|
||||
let circom = &mut *to_circom(ctx_ptr);
|
||||
input
|
||||
.iter()
|
||||
.for_each(|c| (*circom.builder).push_input(name, *c));
|
||||
}));
|
||||
|
||||
to_err_code(result)
|
||||
}
|
||||
|
||||
macro_rules! build_fn
|
||||
{
|
||||
($name:tt, $($v:ident: $t:ty),*) => {
|
||||
|
@ -296,7 +311,7 @@ macro_rules! build_fn
|
|||
input: $($t),*
|
||||
) -> i32 {
|
||||
let result = catch_unwind(AssertUnwindSafe(|| {
|
||||
let name = CStr::from_ptr(name_ptr).to_str().unwrap();
|
||||
let name = CStr::from_ptr(name_ptr).to_str().map_err(|_| ERR_INPUT_NAME).unwrap();
|
||||
let input = U256::from(input);
|
||||
|
||||
let circom = &mut *to_circom(ctx_ptr);
|
||||
|
@ -323,7 +338,7 @@ mod test {
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn groth16_proof() {
|
||||
fn proof_verify() {
|
||||
let r1cs_path = CString::new("./fixtures/mycircuit.r1cs".as_bytes()).unwrap();
|
||||
let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap();
|
||||
|
||||
|
|
Loading…
Reference in New Issue