Rework ffi (#1)

* rework ffi to export serialized types

* add release methods

* cleanup
This commit is contained in:
Dmitriy Ryajov 2024-01-25 15:44:39 -06:00 committed by GitHub
parent aed402f6ee
commit d6299d56ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 422 additions and 116 deletions

View File

@ -25,3 +25,5 @@ ark-poly = { version = "=0.4.1", default-features = false, features = ["parallel
ark-relations = { version = "=0.4.0", default-features = false }
ark-serialize = { version = "=0.4.1", default-features = false }
ruint = { version = "1.7.0", features = ["serde", "num-bigint", "ark-ff"] }
num-bigint = "0.4.3"

View File

@ -4,43 +4,20 @@ use std::{
fs::File,
os::raw::c_void,
panic::{catch_unwind, AssertUnwindSafe},
ptr::slice_from_raw_parts_mut,
};
use ark_bn254::{Bn254, Fr};
use ark_circom::{read_zkey, CircomBuilder, CircomConfig};
use ark_crypto_primitives::snark::SNARK;
use ark_groth16::{prepare_verifying_key, Groth16, Proof, ProvingKey};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
use ark_groth16::{prepare_verifying_key, Groth16, ProvingKey};
use ark_std::rand::{rngs::ThreadRng, thread_rng};
use ruint::aliases::U256;
use crate::ffi_types::*;
type GrothBn = Groth16<Bn254>;
pub const ERR_UNKNOWN: i32 = -1;
pub const ERR_OK: i32 = 0;
pub const ERR_WASM_PATH: i32 = 1;
pub const ERR_R1CS_PATH: i32 = 2;
pub const ERR_ZKEY_PATH: i32 = 3;
pub const ERR_INPUT_NAME: i32 = 4;
pub const ERR_INVALID_INPUT: i32 = 5;
pub const ERR_CANT_READ_ZKEY: i32 = 6;
pub const ERR_CIRCOM_BUILDER: i32 = 7;
pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8;
pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9;
pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10;
pub const ERR_GET_PUB_INPUTS: i32 = 11;
pub const ERR_MAKING_PROOF: i32 = 12;
pub const ERR_SERIALIZE_PROOF: i32 = 13;
pub const ERR_SERIALIZE_INPUTS: i32 = 14;
#[derive(Debug, Clone)]
#[repr(C)]
pub struct Buffer {
data: *const u8,
len: usize,
}
#[derive(Debug, Clone)]
// #[repr(C)]
struct CircomBn254 {
@ -158,11 +135,45 @@ pub unsafe extern "C" fn release_circom_compat(ctx_ptr: &mut *mut CircomCompatCt
// Only use if the buffer was allocated by the ffi
pub unsafe extern "C" fn release_buffer(buff_ptr: &mut *mut Buffer) {
if !buff_ptr.is_null() {
let buff = &mut Box::from_raw(*buff_ptr);
let _ = Box::from_raw(buff.data as *mut u8);
buff.data = std::ptr::null_mut();
buff.len = 0;
*buff_ptr = std::ptr::null_mut();
let buff = Box::from_raw(*buff_ptr);
let data = Box::from_raw(slice_from_raw_parts_mut(buff.data as *mut u8, buff.len));
drop(data);
drop(buff);
}
}
#[no_mangle]
pub unsafe extern "C" fn release_proof(proof_ptr: &mut *mut Proof) {
if !proof_ptr.is_null() {
drop(Box::from_raw(*proof_ptr));
*proof_ptr = std::ptr::null_mut();
}
}
#[no_mangle]
// Only use if the buffer was allocated by the ffi
pub unsafe extern "C" fn release_inputs(inputs_ptr: &mut *mut Inputs) {
if !inputs_ptr.is_null() {
let inputs = Box::from_raw(*inputs_ptr);
let elms = Box::from_raw(slice_from_raw_parts_mut(
inputs.elms as *mut [u8; 32],
inputs.len,
));
drop(elms);
drop(inputs);
*inputs_ptr = std::ptr::null_mut();
}
}
#[no_mangle]
// Only use if the buffer was allocated by the ffi
pub unsafe extern "C" fn release_key(key_ptr: &mut *mut VerifyingKey) {
if !key_ptr.is_null() {
let key = Box::from_raw(*key_ptr);
let ic: Box<[G1]> = Box::from_raw(slice_from_raw_parts_mut(key.ic as *mut G1, key.ic_len));
drop(ic);
drop(key);
*key_ptr = std::ptr::null_mut();
}
}
@ -176,19 +187,39 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
#[allow(private_interfaces)]
pub unsafe extern "C" fn prove_circuit(
ctx_ptr: *mut CircomCompatCtx,
compress: bool,
proof_bytes_ptr: &mut *mut Buffer,
inputs_bytes_ptr: &mut *mut Buffer,
proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key);
let rng = &mut (*ctx_ptr).rng;
let mode = match compress {
true => Compress::Yes,
false => Compress::No,
};
let circuit = (*circom.builder)
.clone()
.build()
.map_err(|_| ERR_CIRCOM_BUILDER)
.unwrap();
let circom_proof = GrothBn::prove(proving_key, circuit, rng)
.map_err(|_| ERR_MAKING_PROOF)
.unwrap();
*proof_ptr = Box::leak(Box::new((&circom_proof).into()));
}));
to_err_code(result)
}
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn get_pub_inputs(
ctx_ptr: *mut CircomCompatCtx,
inputs_ptr: &mut *mut Inputs, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr);
let circuit = (*circom.builder)
.clone()
.build()
@ -199,40 +230,26 @@ pub unsafe extern "C" fn prove_circuit(
.get_public_inputs()
.ok_or_else(|| ERR_GET_PUB_INPUTS)
.unwrap();
let proof = GrothBn::prove(&proving_key, circuit, rng)
.map_err(|_| ERR_MAKING_PROOF)
.unwrap();
*inputs_ptr = Box::leak(Box::new(inputs.as_slice().into()));
}));
let mut proof_bytes = Vec::new();
proof
.serialize_with_mode(&mut proof_bytes, mode)
.map_err(|_| ERR_SERIALIZE_PROOF)
.unwrap();
to_err_code(result)
}
let mut public_inputs_bytes = Vec::new();
inputs
.serialize_with_mode(&mut public_inputs_bytes, mode)
.map_err(|_| ERR_SERIALIZE_INPUTS)
.unwrap();
/// # Safety
///
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn get_verifying_key(
ctx_ptr: *mut CircomCompatCtx,
vk_ptr: &mut *mut VerifyingKey, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key);
let vk = prepare_verifying_key(&proving_key.vk).vk;
// leak the buffers to avoid rust from freeing the pointed to data,
// clone to avoid bytes from being freed
let proof_slice = Box::leak(Box::new(proof_bytes.clone())).as_slice();
let proof_buff = Buffer {
data: proof_slice.as_ptr() as *const u8,
len: proof_bytes.len(),
};
// leak the buffers to avoid rust from freeing the pointed to data,
// clone to avoid bytes from being freed
let input_slice = Box::leak(Box::new(public_inputs_bytes.clone())).as_slice();
let input_buff = Buffer {
data: input_slice.as_ptr() as *const u8,
len: public_inputs_bytes.len(),
};
*proof_bytes_ptr = Box::into_raw(Box::new(proof_buff));
*inputs_bytes_ptr = Box::into_raw(Box::new(input_buff));
*vk_ptr = Box::leak(Box::new((&vk).into()));
}));
to_err_code(result)
@ -243,36 +260,14 @@ pub unsafe extern "C" fn prove_circuit(
#[no_mangle]
#[allow(private_interfaces)]
pub unsafe extern "C" fn verify_circuit(
ctx_ptr: *mut CircomCompatCtx,
compress: bool,
proof_bytes_ptr: *const Buffer,
inputs_bytes_ptr: *const Buffer,
proof: *const Proof,
inputs: *const Inputs,
pvk: *const VerifyingKey,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let mode = match compress {
true => Compress::Yes,
false => Compress::No,
};
let proof_bytes =
std::slice::from_raw_parts((*proof_bytes_ptr).data, (*proof_bytes_ptr).len);
let proof = Proof::<Bn254>::deserialize_with_mode(proof_bytes, mode, Validate::Yes)
.map_err(|_| ERR_FAILED_TO_DESERIALIZE_PROOF)
.unwrap();
let public_inputs_bytes =
std::slice::from_raw_parts((*inputs_bytes_ptr).data, (*inputs_bytes_ptr).len);
let public_inputs: Vec<Fr> =
CanonicalDeserialize::deserialize_with_mode(public_inputs_bytes, mode, Validate::Yes)
.map_err(|_| ERR_FAILED_TO_DESERIALIZE_INPUTS)
.unwrap();
let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key);
let pvk = prepare_verifying_key(&proving_key.vk);
GrothBn::verify_proof(&pvk, &proof, &public_inputs)
let inputs_vec: Vec<Fr> = (*inputs).into();
let prepared_key = prepare_verifying_key(&(*pvk).into());
GrothBn::verify_proof(&prepared_key, &(*proof).into(), inputs_vec.as_slice())
.map_err(|_| ERR_FAILED_TO_VERIFY_PROOF)
.unwrap();
}));
@ -346,9 +341,8 @@ build_fn!(push_input_u64, x: u64);
#[cfg(test)]
mod test {
use std::ffi::CString;
use super::*;
use std::ffi::CString;
#[test]
fn proof_verify() {
@ -372,28 +366,29 @@ mod test {
let b = CString::new("b".as_bytes()).unwrap();
push_input_i8(ctx_ptr, b.as_ptr(), 11);
let mut proof_bytes_ptr: *mut Buffer = std::ptr::null_mut();
let mut inputs_bytes_ptr: *mut Buffer = std::ptr::null_mut();
let mut proof_ptr: *mut Proof = std::ptr::null_mut();
let mut inputs_ptr: *mut Inputs = std::ptr::null_mut();
let mut vk_ptr: *mut VerifyingKey = std::ptr::null_mut();
assert!(prove_circuit(ctx_ptr, true, &mut proof_bytes_ptr, &mut inputs_bytes_ptr) == ERR_OK);
assert!(get_pub_inputs(ctx_ptr, &mut inputs_ptr) == ERR_OK);
assert!(inputs_ptr != std::ptr::null_mut());
assert!(proof_bytes_ptr != std::ptr::null_mut());
assert!((*proof_bytes_ptr).data != std::ptr::null());
assert!((*proof_bytes_ptr).len > 0);
assert!(prove_circuit(ctx_ptr, &mut proof_ptr) == ERR_OK);
assert!(proof_ptr != std::ptr::null_mut());
assert!(inputs_bytes_ptr != std::ptr::null_mut());
assert!((*inputs_bytes_ptr).data != std::ptr::null());
assert!((*inputs_bytes_ptr).len > 0);
assert!(get_verifying_key(ctx_ptr, &mut vk_ptr) == ERR_OK);
assert!(vk_ptr != std::ptr::null_mut());
assert!(verify_circuit(ctx_ptr, true, &(*proof_bytes_ptr), &(*inputs_bytes_ptr)) == ERR_OK);
assert!(verify_circuit(&(*proof_ptr), &(*inputs_ptr), &(*vk_ptr)) == ERR_OK);
release_buffer(&mut proof_bytes_ptr);
release_buffer(&mut inputs_bytes_ptr);
release_circom_compat(&mut ctx_ptr);
release_inputs(&mut inputs_ptr);
assert!(inputs_ptr == std::ptr::null_mut());
assert!(ctx_ptr == std::ptr::null_mut());
assert!(proof_bytes_ptr == std::ptr::null_mut());
assert!(inputs_bytes_ptr == std::ptr::null_mut());
release_proof(&mut proof_ptr);
assert!(proof_ptr == std::ptr::null_mut());
release_key(&mut vk_ptr);
assert!(vk_ptr == std::ptr::null_mut());
};
}
}

308
src/ffi_types.rs Normal file
View File

@ -0,0 +1,308 @@
use std::ptr::slice_from_raw_parts;
use ark_bn254::{Bn254, Fq, Fq2, Fr, G1Affine, G2Affine};
use ark_ff::{BigInteger, PrimeField};
use ark_serialize::CanonicalDeserialize;
use ark_std::Zero;
use num_bigint::BigUint;
pub const ERR_UNKNOWN: i32 = -1;
pub const ERR_OK: i32 = 0;
pub const ERR_WASM_PATH: i32 = 1;
pub const ERR_R1CS_PATH: i32 = 2;
pub const ERR_ZKEY_PATH: i32 = 3;
pub const ERR_INPUT_NAME: i32 = 4;
pub const ERR_INVALID_INPUT: i32 = 5;
pub const ERR_CANT_READ_ZKEY: i32 = 6;
pub const ERR_CIRCOM_BUILDER: i32 = 7;
pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8;
pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9;
pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10;
pub const ERR_GET_PUB_INPUTS: i32 = 11;
pub const ERR_MAKING_PROOF: i32 = 12;
pub const ERR_SERIALIZE_PROOF: i32 = 13;
pub const ERR_SERIALIZE_INPUTS: i32 = 14;
#[derive(Debug, Clone)]
#[repr(C)]
pub struct Buffer {
pub data: *const u8,
pub len: usize,
}
// Helper for converting a PrimeField to little endian byte slice
fn slice_to_point<F: PrimeField>(point: &[u8; 32]) -> F {
let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works");
F::from_bigint(bigint).unwrap()
}
// Helper for converting a PrimeField to its U256 representation for Ethereum compatibility
fn point_to_slice<F: PrimeField>(point: F) -> [u8; 32] {
let point = point.into_bigint();
let point_bytes = point.to_bytes_le();
point_bytes.try_into().expect("always works")
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct G1 {
pub x: [u8; 32],
pub y: [u8; 32],
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct G2 {
pub x: [[u8; 32]; 2],
pub y: [[u8; 32]; 2],
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct Proof {
pub a: G1,
pub b: G2,
pub c: G1,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct VerifyingKey {
pub alpha1: G1,
pub beta2: G2,
pub gamma2: G2,
pub delta2: G2,
pub ic: *const G1,
pub ic_len: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct Inputs {
pub elms: *const [u8; 32],
pub len: usize,
}
impl From<&G1Affine> for G1 {
fn from(src: &G1Affine) -> Self {
Self {
x: point_to_slice(src.x),
y: point_to_slice(src.y),
}
}
}
impl From<&G2Affine> for G2 {
fn from(src: &G2Affine) -> Self {
// We should use the `.as_tuple()` method which handles converting
// the G2 elements to have the second limb first
Self {
x: [point_to_slice(src.x.c0), point_to_slice(src.x.c1)],
y: [point_to_slice(src.y.c0), point_to_slice(src.y.c1)],
}
}
}
impl From<&ark_groth16::Proof<Bn254>> for Proof {
fn from(src: &ark_groth16::Proof<Bn254>) -> Self {
Self {
a: (&src.a).into(),
b: (&src.b).into(),
c: (&src.c).into(),
}
}
}
impl From<G1> for G1Affine {
fn from(src: G1) -> Self {
let x: Fq = slice_to_point(&src.x);
let y: Fq = slice_to_point(&src.y);
if x.is_zero() && y.is_zero() {
G1Affine::identity()
} else {
G1Affine::new(x, y)
}
}
}
impl From<G2> for G2Affine {
fn from(src: G2) -> G2Affine {
let c0 = slice_to_point(&src.x[0]);
let c1 = slice_to_point(&src.x[1]);
let x = Fq2::new(c0, c1);
let c0 = slice_to_point(&src.y[0]);
let c1 = slice_to_point(&src.y[1]);
let y = Fq2::new(c0, c1);
if x.is_zero() && y.is_zero() {
G2Affine::identity()
} else {
G2Affine::new(x, y)
}
}
}
impl From<Proof> for ark_groth16::Proof<Bn254> {
fn from(src: Proof) -> ark_groth16::Proof<Bn254> {
ark_groth16::Proof {
a: src.a.into(),
b: src.b.into(),
c: src.c.into(),
}
}
}
impl From<VerifyingKey> for ark_groth16::VerifyingKey<Bn254> {
fn from(src: VerifyingKey) -> ark_groth16::VerifyingKey<Bn254> {
ark_groth16::VerifyingKey {
alpha_g1: src.alpha1.into(),
beta_g2: src.beta2.into(),
gamma_g2: src.gamma2.into(),
delta_g2: src.delta2.into(),
gamma_abc_g1: unsafe {
std::slice::from_raw_parts(src.ic, src.ic_len)
.iter()
.map(|p| (*p).into())
.collect()
},
}
}
}
impl From<&ark_groth16::VerifyingKey<Bn254>> for VerifyingKey {
fn from(vk: &ark_groth16::VerifyingKey<Bn254>) -> Self {
let ic: Vec<G1> = vk.gamma_abc_g1.iter().map(|p| p.into()).collect();
let len = ic.len();
Self {
alpha1: G1::from(&vk.alpha_g1),
beta2: G2::from(&vk.beta_g2),
gamma2: G2::from(&vk.gamma_g2),
delta2: G2::from(&vk.delta_g2),
ic: Box::leak(Box::new(ic)).as_slice().as_ptr(),
ic_len: len,
}
}
}
impl From<&[Fr]> for Inputs {
fn from(src: &[Fr]) -> Self {
let els: Vec<[u8; 32]> = src
.iter()
.map(|point| point.0.to_bytes_le().try_into().unwrap())
.collect();
let len = els.len();
Self {
elms: Box::leak(els.into_boxed_slice()).as_ptr(),
len: len,
}
}
}
impl From<Inputs> for Vec<Fr> {
fn from(src: Inputs) -> Self {
let els: Vec<Fr> = unsafe {
(&*slice_from_raw_parts(src.elms, src.len))
.iter()
.map(|point| {
let uint = BigUint::from_bytes_le(point);
Fr::from(uint)
})
.collect()
};
els
}
}
#[cfg(test)]
mod test {
use ark_std::UniformRand;
use super::*;
fn fq() -> Fq {
Fq::from(2)
}
fn fr() -> Fr {
Fr::from(2)
}
fn g1() -> G1Affine {
let rng = &mut ark_std::test_rng();
G1Affine::rand(rng)
}
fn g2() -> G2Affine {
let rng = &mut ark_std::test_rng();
G2Affine::rand(rng)
}
#[test]
fn convert_fq() {
let el = fq();
let el2 = point_to_slice(el);
let el3: Fq = slice_to_point(&el2);
let el4 = point_to_slice(el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
}
#[test]
fn convert_fr() {
let el = fr();
let el2 = point_to_slice(el);
let el3: Fr = slice_to_point(&el2);
let el4 = point_to_slice(el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
}
#[test]
fn convert_g1() {
let el = g1();
let el2 = G1::from(&el);
let el3: G1Affine = el2.into();
let el4 = G1::from(&el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
}
#[test]
fn convert_g2() {
let el = g2();
let el2 = G2::from(&el);
let el3: G2Affine = el2.into();
let el4 = G2::from(&el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
}
#[test]
fn convert_vk() {
let vk = ark_groth16::VerifyingKey::<Bn254> {
alpha_g1: g1(),
beta_g2: g2(),
gamma_g2: g2(),
delta_g2: g2(),
gamma_abc_g1: vec![g1(), g1(), g1()],
};
let vk_ffi = &VerifyingKey::from(&vk);
let ark_vk: ark_groth16::VerifyingKey<Bn254> = (*vk_ffi).into();
assert_eq!(ark_vk, vk);
}
#[test]
fn convert_proof() {
let p = ark_groth16::Proof::<Bn254> {
a: g1(),
b: g2(),
c: g1(),
};
let p2 = Proof::from(&p);
let p3 = ark_groth16::Proof::from(p2);
assert_eq!(p, p3);
}
}

View File

@ -1 +1,2 @@
pub mod ffi;
pub mod ffi_types;