minor cleanup and err handling

This commit is contained in:
Dmitriy Ryajov 2024-01-20 16:03:07 -06:00
parent 94cfed822c
commit 87fa8cdb96
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
1 changed files with 55 additions and 40 deletions

View File

@ -3,7 +3,7 @@ use std::{
ffi::{c_char, CStr},
fs::File,
os::raw::c_void,
panic::{catch_unwind, AssertUnwindSafe},
panic::{catch_unwind, AssertUnwindSafe}, io::Cursor,
};
use ark_bn254::{Bn254, Fr};
@ -28,6 +28,10 @@ pub const ERR_CANT_READ_ZKEY: i32 = 6;
pub const ERR_CIRCOM_BUILDER: i32 = 7;
pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8;
pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9;
pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10;
pub const ERR_GET_PUB_INPUTS: i32 = 11;
pub const ERR_MAKING_PROOF: i32 = 12;
pub const ERR_SERIALIZE_PROOF: i32 = 13;
#[derive(Debug, Clone)]
#[repr(C)]
@ -150,6 +154,7 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt
}
#[no_mangle]
// Only use if the buffer was allocated by the ffi
pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) {
if !buff_ptr.is_null() {
let buff = &mut Box::from_raw(*buff_ptr);
@ -164,39 +169,6 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
(*ctx_ptr).circom as *mut CircomBn254
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn push_input_u256_array(
ctx_ptr: *mut CircomCompatCtx,
name_ptr: *const c_char,
input_ptr: *const u8,
len: usize,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let name = CStr::from_ptr(name_ptr)
.to_str()
.map_err(|_| ERR_INPUT_NAME)
.unwrap();
let input = {
let slice = std::slice::from_raw_parts(input_ptr, len);
slice
.chunks(U256::BYTES)
.map(|c| U256::try_from_le_slice(c).ok_or(ERR_INVALID_INPUT).unwrap())
.collect::<Vec<U256>>()
};
let circom = &mut *to_circom(ctx_ptr);
input
.iter()
.for_each(|c| (*circom.builder).push_input(name, *c));
}));
to_err_code(result)
}
/// # Safety
///
#[no_mangle]
@ -212,10 +184,19 @@ pub unsafe extern "C" fn prove_circuit(
let proving_key = &(*circom.proving_key);
let rng = &mut (*ctx_ptr).rng;
let circuit = (*circom.builder).clone().build().unwrap();
let circuit = (*circom.builder)
.clone()
.build()
.map_err(|_| ERR_CIRCOM_BUILDER)
.unwrap();
let inputs = circuit.get_public_inputs().unwrap();
let proof = GrothBn::prove(&proving_key, circuit, rng).unwrap();
let inputs = circuit
.get_public_inputs()
.ok_or_else(|| ERR_GET_PUB_INPUTS)
.unwrap();
let proof = GrothBn::prove(&proving_key, circuit, rng)
.map_err(|_| ERR_MAKING_PROOF)
.unwrap();
let mut proof_bytes = Vec::new();
proof.serialize_compressed(&mut proof_bytes).unwrap();
@ -223,6 +204,7 @@ pub unsafe extern "C" fn prove_circuit(
let mut public_inputs_bytes = Vec::new();
inputs
.serialize_compressed(&mut public_inputs_bytes)
.map_err(|_| ERR_SERIALIZE_PROOF)
.unwrap();
// leak the buffers to avoid rust from freeing the pointed to data,
@ -278,13 +260,46 @@ pub unsafe extern "C" fn verify_circuit(
let pvk = prepare_verifying_key(&proving_key.vk);
GrothBn::verify_proof(&pvk, &proof, &public_inputs)
.map_err(|e| e.to_string())
.map_err(|_| ERR_FAILED_TO_VERIFY_PROOF)
.unwrap();
}));
to_err_code(result)
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn push_input_u256_array(
ctx_ptr: *mut CircomCompatCtx,
name_ptr: *const c_char,
input_ptr: *const u8,
len: usize,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let name = CStr::from_ptr(name_ptr)
.to_str()
.map_err(|_| ERR_INPUT_NAME)
.unwrap();
let input = {
let slice = std::slice::from_raw_parts(input_ptr, len);
slice
.chunks(U256::BYTES)
.map(|c| U256::try_from_le_slice(c).ok_or(ERR_INVALID_INPUT).unwrap())
.collect::<Vec<U256>>()
};
let circom = &mut *to_circom(ctx_ptr);
input
.iter()
.for_each(|c| (*circom.builder).push_input(name, *c));
}));
to_err_code(result)
}
macro_rules! build_fn
{
($name:tt, $($v:ident: $t:ty),*) => {
@ -296,7 +311,7 @@ macro_rules! build_fn
input: $($t),*
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let name = CStr::from_ptr(name_ptr).to_str().unwrap();
let name = CStr::from_ptr(name_ptr).to_str().map_err(|_| ERR_INPUT_NAME).unwrap();
let input = U256::from(input);
let circom = &mut *to_circom(ctx_ptr);
@ -323,7 +338,7 @@ mod test {
use super::*;
#[test]
fn groth16_proof() {
fn proof_verify() {
let r1cs_path = CString::new("./fixtures/mycircuit.r1cs".as_bytes()).unwrap();
let wasm_path = CString::new("./fixtures/mycircuit.wasm".as_bytes()).unwrap();