This commit is contained in:
Dmitriy Ryajov 2025-03-06 21:27:16 -06:00
parent d2023e419f
commit b05238f09a
3 changed files with 178 additions and 69 deletions

View File

@ -8,6 +8,10 @@ opt-level = 3 # Use slightly better optimizations.
debug = true # Generate debug info.
debug-assertions = true # Enable debug assertions.
[profile.release]
opt-level = 3 # Maximum optimization level
debug = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
crate-type = [

View File

@ -3,7 +3,6 @@ use std::{
ffi::{c_char, CStr},
fs::File,
panic::{catch_unwind, AssertUnwindSafe},
ptr::slice_from_raw_parts_mut,
};
use crate::ffi_types::*;
@ -13,6 +12,7 @@ use ark_crypto_primitives::snark::SNARK;
use ark_groth16::{prepare_verifying_key, Groth16, ProvingKey};
use ark_std::rand::thread_rng;
use ruint::aliases::U256;
use std::sync::{Arc, Mutex};
type GrothBn = Groth16<Bn254, CircomReduction>;
@ -35,21 +35,18 @@ pub const ERR_SERIALIZE_INPUTS: i32 = 14;
#[derive(Debug, Clone)]
struct CircomBn254Cfg {
cfg: *mut CircomConfig<Bn254>,
proving_key: *mut ProvingKey<Bn254>,
_marker: core::marker::PhantomData<(*mut CircomBn254Cfg, core::marker::PhantomPinned)>,
cfg: Arc<Mutex<CircomConfig<Bn254>>>,
proving_key: Arc<Mutex<ProvingKey<Bn254>>>,
}
#[derive(Debug, Clone)]
struct CircomBn254 {
builder: *mut CircomBuilder<Bn254>,
_marker: core::marker::PhantomData<(*mut CircomBn254, core::marker::PhantomPinned)>,
builder: Arc<Mutex<CircomBuilder<Bn254>>>,
}
#[derive(Debug, Clone)]
struct CircomCompatCtx {
circom: *mut CircomBn254,
_marker: core::marker::PhantomData<(*mut CircomCompatCtx, core::marker::PhantomPinned)>,
circom: Arc<Mutex<CircomBn254>>,
}
fn to_err_code(result: Result<i32, Box<dyn Any + Send>>) -> i32 {
@ -108,9 +105,8 @@ pub unsafe extern "C" fn init_circom_config_with_checks(
};
let circom_bn254_cfg = CircomBn254Cfg {
cfg: Box::into_raw(Box::new(cfg)),
proving_key: Box::into_raw(Box::new(proving_key)),
_marker: std::marker::PhantomData,
cfg: Arc::new(Mutex::new(cfg)),
proving_key: Arc::new(Mutex::new(proving_key)),
};
*cfg_ptr = Box::into_raw(Box::new(circom_bn254_cfg));
@ -139,15 +135,13 @@ pub unsafe extern "C" fn init_circom_compat(
ctx_ptr: &mut *mut CircomCompatCtx,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let builder = CircomBuilder::new((*(*cfg_ptr).cfg).clone()); // clone the config
let builder = CircomBuilder::new((*(*cfg_ptr).cfg.lock().unwrap()).clone()); // clone the config
let circom_bn254 = CircomBn254 {
builder: Box::into_raw(Box::new(builder)),
_marker: core::marker::PhantomData,
builder: Arc::new(Mutex::new(builder)),
};
let circom_compat_ctx = CircomCompatCtx {
circom: Box::into_raw(Box::new(circom_bn254)),
_marker: core::marker::PhantomData,
circom: Arc::new(Mutex::new(circom_bn254)),
};
*ctx_ptr = Box::into_raw(Box::new(circom_compat_ctx));
@ -162,14 +156,7 @@ pub unsafe extern "C" fn init_circom_compat(
#[allow(private_interfaces)]
pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCtx) {
if !ctx_ptr.is_null() {
let ctx = &mut Box::from_raw(*ctx_ptr);
if !ctx.circom.is_null() {
let circom = &mut Box::from_raw(ctx.circom);
let builder = Box::from_raw(circom.builder);
drop(builder);
}
drop(Box::from_raw(*ctx_ptr));
*ctx_ptr = std::ptr::null_mut();
}
}
@ -179,8 +166,6 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt
pub unsafe extern "C" fn release_cfg(cfg_ptr: &mut *mut CircomBn254Cfg) {
if !cfg_ptr.is_null() && !(*cfg_ptr).is_null() {
let cfg = Box::from_raw(*cfg_ptr);
drop(Box::from_raw((*cfg).proving_key));
drop(Box::from_raw((*cfg).cfg));
drop(cfg);
*cfg_ptr = std::ptr::null_mut();
}
@ -199,11 +184,7 @@ pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) {
pub unsafe extern "C" fn release_inputs(inputs_ptr: &mut *mut Inputs) {
if !inputs_ptr.is_null() {
let inputs = Box::from_raw(*inputs_ptr);
let elms = Box::from_raw(slice_from_raw_parts_mut(
inputs.elms as *mut [u8; 32],
inputs.len,
));
drop(elms);
inputs.free();
drop(inputs);
*inputs_ptr = std::ptr::null_mut();
}
@ -214,16 +195,15 @@ pub unsafe extern "C" fn release_inputs(inputs_ptr: &mut *mut Inputs) {
pub unsafe extern "C" fn release_key(key_ptr: &mut *mut VerifyingKey) {
if !key_ptr.is_null() {
let key = Box::from_raw(*key_ptr);
let ic= Box::from_raw(slice_from_raw_parts_mut(key.ic as *mut G1, key.ic_len));
drop(ic);
key.free();
drop(key);
*key_ptr = std::ptr::null_mut();
}
}
unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
(*ctx_ptr).circom as *mut CircomBn254
}
// unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
// (*ctx_ptr).circom.lock().unwrap() as *mut CircomBn254
// }
#[no_mangle]
#[allow(private_interfaces)]
@ -233,11 +213,11 @@ pub unsafe extern "C" fn prove_circuit(
proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr);
let proving_key = (*(*cfg_ptr).proving_key).clone();
let circom = ((*ctx_ptr).circom.lock().unwrap()).clone();
let proving_key = (*(*cfg_ptr).proving_key.lock().unwrap()).clone();
let mut rng = thread_rng();
let circuit = (*circom.builder)
let circuit = (*circom.builder.lock().unwrap())
.clone()
.build()
.map_err(|_| ERR_CIRCOM_BUILDER)
@ -247,7 +227,7 @@ pub unsafe extern "C" fn prove_circuit(
.map_err(|_| ERR_MAKING_PROOF)
.unwrap();
*proof_ptr = Box::leak(Box::new((&circom_proof).into()));
*proof_ptr = Box::into_raw(Box::new((&circom_proof).into()));
ERR_OK
}));
@ -262,8 +242,8 @@ pub unsafe extern "C" fn get_pub_inputs(
inputs_ptr: &mut *mut Inputs, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr);
let circuit = (*circom.builder)
let circom = (*ctx_ptr).circom.lock().unwrap();
let circuit = (*circom.builder.lock().unwrap())
.clone()
.build()
.map_err(|_| ERR_CIRCOM_BUILDER)
@ -273,7 +253,7 @@ pub unsafe extern "C" fn get_pub_inputs(
.get_public_inputs()
.ok_or_else(|| ERR_GET_PUB_INPUTS)
.unwrap();
*inputs_ptr = Box::leak(Box::new(inputs.as_slice().into()));
*inputs_ptr = Box::into_raw(Box::new(inputs.as_slice().into()));
ERR_OK
}));
@ -289,7 +269,7 @@ pub unsafe extern "C" fn get_verifying_key(
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let ctx = &mut *cfg_ptr;
let proving_key = &(*(*ctx).proving_key);
let proving_key = &(*(*ctx).proving_key.lock().unwrap());
let vk = prepare_verifying_key(&proving_key.vk).vk;
*vk_ptr = Box::into_raw(Box::new((&vk).into()));
@ -308,7 +288,7 @@ pub unsafe extern "C" fn verify_circuit(
pvk: *const VerifyingKey,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let inputs_vec: Vec<Fr> = (*inputs).into();
let inputs_vec: Vec<Fr> = (*inputs).clone().into();
let prepared_key = prepare_verifying_key(&(*pvk).into());
let passed = GrothBn::verify_proof(&prepared_key, &(*proof).into(), inputs_vec.as_slice())
@ -350,10 +330,10 @@ pub unsafe extern "C" fn push_input_u256_array(
.map(|c| U256::try_from_le_slice(c).ok_or(ERR_INVALID_INPUT).unwrap())
.collect::<Vec<U256>>();
let circom = &mut *to_circom(ctx_ptr);
let circom = (*ctx_ptr).circom.lock().unwrap();
inputs
.iter()
.for_each(|c| (*circom.builder).push_input(name, *c));
.for_each(|c| (*circom.builder.lock().unwrap()).push_input(name, *c));
ERR_OK
}));
@ -375,8 +355,8 @@ macro_rules! build_fn
let name = CStr::from_ptr(name_ptr).to_str().map_err(|_| ERR_INPUT_NAME).unwrap();
let input = U256::from(input);
let circom = &mut *to_circom(ctx_ptr);
(*circom.builder).push_input(name, input);
let circom = (*ctx_ptr).circom.lock().unwrap();
(*circom.builder.lock().unwrap()).push_input(name, input);
ERR_OK
}));
@ -400,7 +380,6 @@ mod test {
use std::ffi::CString;
#[test]
#[ignore] // TODO: getting alignment issues for this circuit, need to investigate further
fn proof_verify() {
let r1cs_path = CString::new("./fixtures/circom2_multiplier2.r1cs".as_bytes()).unwrap();
let wasm_path = CString::new("./fixtures/circom2_multiplier2.wasm".as_bytes()).unwrap();
@ -519,4 +498,102 @@ mod test {
assert!(cfg_ptr == std::ptr::null_mut());
};
}
// Wrapper to make raw pointers safely sendable between threads
// This is safe because we know the pointers are protected by Arc<Mutex<>> internally
struct ThreadSafePointer<T>(*const T);
unsafe impl<T> Send for ThreadSafePointer<T> {}
#[test]
fn multithreaded_prove_and_verify() {
use std::sync::{Arc, Barrier};
use std::thread;
let r1cs_path = CString::new("./fixtures/circom2_multiplier2.r1cs".as_bytes()).unwrap();
let wasm_path = CString::new("./fixtures/circom2_multiplier2.wasm".as_bytes()).unwrap();
let zkey_path = CString::new("./fixtures/test.zkey".as_bytes()).unwrap();
unsafe {
// Setup in the main thread
let mut cfg_ptr: *mut CircomBn254Cfg = std::ptr::null_mut();
init_circom_config(
r1cs_path.as_ptr(),
wasm_path.as_ptr(),
zkey_path.as_ptr(),
&mut cfg_ptr,
);
assert!(cfg_ptr != std::ptr::null_mut());
let mut ctx_ptr: *mut CircomCompatCtx = std::ptr::null_mut();
init_circom_compat(cfg_ptr, &mut ctx_ptr);
assert!(ctx_ptr != std::ptr::null_mut());
// Push inputs
let a = CString::new("a".as_bytes()).unwrap();
push_input_i8(ctx_ptr, a.as_ptr(), 3);
let b = CString::new("b".as_bytes()).unwrap();
push_input_i8(ctx_ptr, b.as_ptr(), 11);
// Create a barrier to synchronize threads
let barrier = Arc::new(Barrier::new(2));
let barrier_clone = barrier.clone();
// Wrap the pointers in a thread-safe container
let thread_cfg_ptr = Arc::new(Mutex::new(ThreadSafePointer(
cfg_ptr as *const CircomBn254Cfg,
)));
let thread_ctx_ptr = Arc::new(Mutex::new(ThreadSafePointer(
ctx_ptr as *const CircomCompatCtx,
)));
let thread_cfg_ptr_clone = Arc::clone(&thread_cfg_ptr);
let thread_ctx_ptr_clone = Arc::clone(&thread_ctx_ptr);
// Spawn thread for both proving and verification
let prover_thread = thread::spawn(move || {
// Unwrap the pointers
let thread_cfg_ptr = thread_cfg_ptr_clone.lock().unwrap().0 as *mut CircomBn254Cfg;
let thread_ctx_ptr = thread_ctx_ptr_clone.lock().unwrap().0 as *mut CircomCompatCtx;
// Create pointers for proof, inputs, and verification key
let mut proof_ptr: *mut Proof = std::ptr::null_mut();
let mut inputs_ptr: *mut Inputs = std::ptr::null_mut();
let mut vk_ptr: *mut VerifyingKey = std::ptr::null_mut();
// Get public inputs
assert!(get_pub_inputs(thread_ctx_ptr, &mut inputs_ptr) == ERR_OK);
assert!(inputs_ptr != std::ptr::null_mut());
// Generate proof
assert!(prove_circuit(thread_cfg_ptr, thread_ctx_ptr, &mut proof_ptr) == ERR_OK);
assert!(proof_ptr != std::ptr::null_mut());
// Get verification key
assert!(get_verifying_key(thread_cfg_ptr, &mut vk_ptr) == ERR_OK);
assert!(vk_ptr != std::ptr::null_mut());
// Verify the proof in the same thread
assert!(verify_circuit(&(*proof_ptr), &(*inputs_ptr), &(*vk_ptr)) == ERR_OK);
// Clean up resources in this thread
release_proof(&mut proof_ptr);
release_inputs(&mut inputs_ptr);
release_key(&mut vk_ptr);
// Signal that we're done
barrier_clone.wait();
});
// Wait for the prover thread to complete
barrier.wait();
// Clean up remaining resources in the main thread
release_circom_compat(&mut ctx_ptr);
release_cfg(&mut cfg_ptr);
// Wait for the thread to finish
prover_thread.join().unwrap();
};
}
}

View File

@ -1,4 +1,4 @@
use std::ptr::slice_from_raw_parts;
use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut};
use ark_bn254::{Bn254, Fq, Fq2, Fr, G1Affine, G2Affine};
use ark_ff::{BigInteger, PrimeField};
@ -6,7 +6,7 @@ use ark_serialize::CanonicalDeserialize;
use ark_std::Zero;
// Helper for converting a PrimeField to little endian byte slice
fn slice_to_point<F: PrimeField>(point: &[u8; 32]) -> F {
fn slice_to_point<F: PrimeField>(point: [u8; 32]) -> F {
let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works");
F::from_bigint(bigint).unwrap()
}
@ -90,8 +90,8 @@ impl From<&ark_groth16::Proof<Bn254>> for Proof {
impl From<G1> for G1Affine {
fn from(src: G1) -> Self {
let x: Fq = slice_to_point(&src.x);
let y: Fq = slice_to_point(&src.y);
let x: Fq = slice_to_point(src.x);
let y: Fq = slice_to_point(src.y);
if x.is_zero() && y.is_zero() {
G1Affine::identity()
} else {
@ -102,12 +102,12 @@ impl From<G1> for G1Affine {
impl From<G2> for G2Affine {
fn from(src: G2) -> G2Affine {
let c0 = slice_to_point(&src.x[0]);
let c1 = slice_to_point(&src.x[1]);
let c0 = slice_to_point(src.x[0]);
let c1 = slice_to_point(src.x[1]);
let x = Fq2::new(c0, c1);
let c0 = slice_to_point(&src.y[0]);
let c1 = slice_to_point(&src.y[1]);
let c0 = slice_to_point(src.y[0]);
let c1 = slice_to_point(src.y[1]);
let y = Fq2::new(c0, c1);
if x.is_zero() && y.is_zero() {
@ -136,7 +136,7 @@ impl From<VerifyingKey> for ark_groth16::VerifyingKey<Bn254> {
gamma_g2: src.gamma2.into(),
delta_g2: src.delta2.into(),
gamma_abc_g1: unsafe {
std::slice::from_raw_parts(src.ic, src.ic_len)
(*slice_from_raw_parts(src.ic, src.ic_len))
.iter()
.map(|p| (*p).into())
.collect()
@ -145,10 +145,39 @@ impl From<VerifyingKey> for ark_groth16::VerifyingKey<Bn254> {
}
}
impl VerifyingKey {
pub fn free(mut self) {
unsafe {
if !self.ic.is_null() && self.ic_len > 0 {
drop(Box::from_raw(slice_from_raw_parts_mut(
self.ic as *mut G1,
self.ic_len,
)));
self.ic = std::ptr::null();
self.ic_len = 0;
}
}
}
}
impl Inputs {
pub fn free(mut self) {
unsafe {
if !self.elms.is_null() && self.len > 0 {
drop(Box::from_raw(slice_from_raw_parts_mut(
self.elms as *mut [u8; 32],
self.len,
)));
self.elms = std::ptr::null();
self.len = 0;
}
}
}
}
impl From<&ark_groth16::VerifyingKey<Bn254>> for VerifyingKey {
fn from(vk: &ark_groth16::VerifyingKey<Bn254>) -> Self {
let mut ic: Vec<G1> = vk.gamma_abc_g1.iter().map(|p| p.into()).collect();
ic.shrink_to_fit();
let ic: Vec<G1> = vk.gamma_abc_g1.iter().map(|p| p.into()).collect();
let len = ic.len();
Self {
@ -164,12 +193,10 @@ impl From<&ark_groth16::VerifyingKey<Bn254>> for VerifyingKey {
impl From<&[Fr]> for Inputs {
fn from(src: &[Fr]) -> Self {
let mut els: Vec<[u8; 32]> = src.iter().map(|point| point_to_slice(*point)).collect();
els.shrink_to_fit();
let els: Vec<[u8; 32]> = src.iter().map(|point| point_to_slice(*point)).collect();
let len = els.len();
Self {
elms: Box::leak(els.into_boxed_slice()).as_ptr(),
elms: Box::into_raw(els.into_boxed_slice()) as *const [u8; 32],
len: len,
}
}
@ -178,9 +205,9 @@ impl From<&[Fr]> for Inputs {
impl From<Inputs> for Vec<Fr> {
fn from(src: Inputs) -> Self {
let els: Vec<Fr> = unsafe {
(&*slice_from_raw_parts(src.elms, src.len))
(*slice_from_raw_parts(src.elms, src.len))
.iter()
.map(|point| slice_to_point(point))
.map(|point| slice_to_point(*point))
.collect()
};
@ -216,7 +243,7 @@ mod test {
fn convert_fq() {
let el = fq();
let el2 = point_to_slice(el);
let el3: Fq = slice_to_point(&el2);
let el3: Fq = slice_to_point(el2);
let el4 = point_to_slice(el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
@ -226,7 +253,7 @@ mod test {
fn convert_fr() {
let el = fr();
let el2 = point_to_slice(el);
let el3: Fr = slice_to_point(&el2);
let el3: Fr = slice_to_point(el2);
let el4 = point_to_slice(el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
@ -264,6 +291,7 @@ mod test {
let vk_ffi = &VerifyingKey::from(&vk);
let ark_vk: ark_groth16::VerifyingKey<Bn254> = (*vk_ffi).into();
assert_eq!(ark_vk, vk);
vk_ffi.free();
}
#[test]