rework ffi to export serialized types

This commit is contained in:
Dmitriy Ryajov 2024-01-25 13:05:38 -06:00
parent aed402f6ee
commit d6bb7a6b30
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
4 changed files with 341 additions and 100 deletions

View File

@ -25,3 +25,5 @@ ark-poly = { version = "=0.4.1", default-features = false, features = ["parallel
ark-relations = { version = "=0.4.0", default-features = false }
ark-serialize = { version = "=0.4.1", default-features = false }
ruint = { version = "1.7.0", features = ["serde", "num-bigint", "ark-ff"] }
num-bigint = "0.4.3"

View File

@ -9,38 +9,15 @@ use std::{
use ark_bn254::{Bn254, Fr};
use ark_circom::{read_zkey, CircomBuilder, CircomConfig};
use ark_crypto_primitives::snark::SNARK;
use ark_groth16::{prepare_verifying_key, Groth16, Proof, ProvingKey};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
use ark_groth16::{prepare_verifying_key, Groth16, Proof as Groth16Proof, ProvingKey};
use ark_serialize::Compress;
use ark_std::rand::{rngs::ThreadRng, thread_rng};
use ruint::aliases::U256;
use crate::ffi_types::*;
type GrothBn = Groth16<Bn254>;
pub const ERR_UNKNOWN: i32 = -1;
pub const ERR_OK: i32 = 0;
pub const ERR_WASM_PATH: i32 = 1;
pub const ERR_R1CS_PATH: i32 = 2;
pub const ERR_ZKEY_PATH: i32 = 3;
pub const ERR_INPUT_NAME: i32 = 4;
pub const ERR_INVALID_INPUT: i32 = 5;
pub const ERR_CANT_READ_ZKEY: i32 = 6;
pub const ERR_CIRCOM_BUILDER: i32 = 7;
pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8;
pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9;
pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10;
pub const ERR_GET_PUB_INPUTS: i32 = 11;
pub const ERR_MAKING_PROOF: i32 = 12;
pub const ERR_SERIALIZE_PROOF: i32 = 13;
pub const ERR_SERIALIZE_INPUTS: i32 = 14;
#[derive(Debug, Clone)]
#[repr(C)]
pub struct Buffer {
data: *const u8,
len: usize,
}
#[derive(Debug, Clone)]
// #[repr(C)]
struct CircomBn254 {
@ -176,18 +153,13 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
#[allow(private_interfaces)]
pub unsafe extern "C" fn prove_circuit(
ctx_ptr: *mut CircomCompatCtx,
compress: bool,
proof_bytes_ptr: &mut *mut Buffer,
inputs_bytes_ptr: &mut *mut Buffer,
proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer,
inputs_ptr: &mut *mut Inputs, // inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key);
let rng = &mut (*ctx_ptr).rng;
let mode = match compress {
true => Compress::Yes,
false => Compress::No,
};
let circuit = (*circom.builder)
.clone()
@ -199,40 +171,13 @@ pub unsafe extern "C" fn prove_circuit(
.get_public_inputs()
.ok_or_else(|| ERR_GET_PUB_INPUTS)
.unwrap();
let proof = GrothBn::prove(&proving_key, circuit, rng)
let circomProof = GrothBn::prove(proving_key, circuit, rng)
.map_err(|_| ERR_MAKING_PROOF)
.unwrap();
let mut proof_bytes = Vec::new();
proof
.serialize_with_mode(&mut proof_bytes, mode)
.map_err(|_| ERR_SERIALIZE_PROOF)
.unwrap();
let mut public_inputs_bytes = Vec::new();
inputs
.serialize_with_mode(&mut public_inputs_bytes, mode)
.map_err(|_| ERR_SERIALIZE_INPUTS)
.unwrap();
// leak the buffers to avoid rust from freeing the pointed to data,
// clone to avoid bytes from being freed
let proof_slice = Box::leak(Box::new(proof_bytes.clone())).as_slice();
let proof_buff = Buffer {
data: proof_slice.as_ptr() as *const u8,
len: proof_bytes.len(),
};
// leak the buffers to avoid rust from freeing the pointed to data,
// clone to avoid bytes from being freed
let input_slice = Box::leak(Box::new(public_inputs_bytes.clone())).as_slice();
let input_buff = Buffer {
data: input_slice.as_ptr() as *const u8,
len: public_inputs_bytes.len(),
};
*proof_bytes_ptr = Box::into_raw(Box::new(proof_buff));
*inputs_bytes_ptr = Box::into_raw(Box::new(input_buff));
*proof_ptr = Box::leak(Box::new((&circomProof).into()));
*inputs_ptr = Box::leak(Box::new(inputs.as_slice().into()));
}));
to_err_code(result)
@ -245,8 +190,8 @@ pub unsafe extern "C" fn prove_circuit(
pub unsafe extern "C" fn verify_circuit(
ctx_ptr: *mut CircomCompatCtx,
compress: bool,
proof_bytes_ptr: *const Buffer,
inputs_bytes_ptr: *const Buffer,
proof: *const Proof,
inputs: *const Inputs,
) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| {
let mode = match compress {
@ -254,25 +199,12 @@ pub unsafe extern "C" fn verify_circuit(
false => Compress::No,
};
let proof_bytes =
std::slice::from_raw_parts((*proof_bytes_ptr).data, (*proof_bytes_ptr).len);
let proof = Proof::<Bn254>::deserialize_with_mode(proof_bytes, mode, Validate::Yes)
.map_err(|_| ERR_FAILED_TO_DESERIALIZE_PROOF)
.unwrap();
let public_inputs_bytes =
std::slice::from_raw_parts((*inputs_bytes_ptr).data, (*inputs_bytes_ptr).len);
let public_inputs: Vec<Fr> =
CanonicalDeserialize::deserialize_with_mode(public_inputs_bytes, mode, Validate::Yes)
.map_err(|_| ERR_FAILED_TO_DESERIALIZE_INPUTS)
.unwrap();
let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key);
let pvk = prepare_verifying_key(&proving_key.vk);
GrothBn::verify_proof(&pvk, &proof, &public_inputs)
let inputs_vec: Vec<Fr> = (*inputs).into();
GrothBn::verify_proof(&pvk, &(*proof).into(), inputs_vec.as_slice())
.map_err(|_| ERR_FAILED_TO_VERIFY_PROOF)
.unwrap();
}));
@ -347,7 +279,6 @@ build_fn!(push_input_u64, x: u64);
#[cfg(test)]
mod test {
use std::ffi::CString;
use super::*;
#[test]
@ -372,28 +303,25 @@ mod test {
let b = CString::new("b".as_bytes()).unwrap();
push_input_i8(ctx_ptr, b.as_ptr(), 11);
let mut proof_bytes_ptr: *mut Buffer = std::ptr::null_mut();
let mut inputs_bytes_ptr: *mut Buffer = std::ptr::null_mut();
let mut proof_ptr: *mut Proof = std::ptr::null_mut();
let mut inputs_ptr: *mut Inputs = std::ptr::null_mut();
assert!(prove_circuit(ctx_ptr, true, &mut proof_bytes_ptr, &mut inputs_bytes_ptr) == ERR_OK);
assert!(prove_circuit(ctx_ptr, &mut proof_ptr, &mut inputs_ptr) == ERR_OK);
assert!(proof_bytes_ptr != std::ptr::null_mut());
assert!((*proof_bytes_ptr).data != std::ptr::null());
assert!((*proof_bytes_ptr).len > 0);
assert!(proof_ptr != std::ptr::null_mut());
assert!(inputs_ptr != std::ptr::null_mut());
assert!(inputs_bytes_ptr != std::ptr::null_mut());
assert!((*inputs_bytes_ptr).data != std::ptr::null());
assert!((*inputs_bytes_ptr).len > 0);
assert!(
verify_circuit(ctx_ptr, true, &(*proof_ptr), &(*inputs_ptr)) == ERR_OK
);
assert!(verify_circuit(ctx_ptr, true, &(*proof_bytes_ptr), &(*inputs_bytes_ptr)) == ERR_OK);
// release_buffer(&mut proof_bytes_ptr);
// release_buffer(&mut inputs_bytes_ptr);
// release_circom_compat(&mut ctx_ptr);
release_buffer(&mut proof_bytes_ptr);
release_buffer(&mut inputs_bytes_ptr);
release_circom_compat(&mut ctx_ptr);
assert!(ctx_ptr == std::ptr::null_mut());
assert!(proof_bytes_ptr == std::ptr::null_mut());
assert!(inputs_bytes_ptr == std::ptr::null_mut());
// assert!(ctx_ptr == std::ptr::null_mut());
// assert!(proof_bytes_ptr == std::ptr::null_mut());
// assert!(inputs_bytes_ptr == std::ptr::null_mut());
};
}
}

310
src/ffi_types.rs Normal file
View File

@ -0,0 +1,310 @@
use std::ptr::slice_from_raw_parts;
use ark_bn254::{Bn254, Fq, Fq2, Fr, G1Affine, G2Affine};
use ark_ff::{BigInteger, PrimeField};
use ark_groth16::{Groth16, Proof as Groth16Proof};
use ark_serialize::CanonicalDeserialize;
use ark_std::Zero;
use num_bigint::BigUint;
type GrothBn = Groth16<Bn254>;
pub const ERR_UNKNOWN: i32 = -1;
pub const ERR_OK: i32 = 0;
pub const ERR_WASM_PATH: i32 = 1;
pub const ERR_R1CS_PATH: i32 = 2;
pub const ERR_ZKEY_PATH: i32 = 3;
pub const ERR_INPUT_NAME: i32 = 4;
pub const ERR_INVALID_INPUT: i32 = 5;
pub const ERR_CANT_READ_ZKEY: i32 = 6;
pub const ERR_CIRCOM_BUILDER: i32 = 7;
pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8;
pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9;
pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10;
pub const ERR_GET_PUB_INPUTS: i32 = 11;
pub const ERR_MAKING_PROOF: i32 = 12;
pub const ERR_SERIALIZE_PROOF: i32 = 13;
pub const ERR_SERIALIZE_INPUTS: i32 = 14;
#[derive(Debug, Clone)]
#[repr(C)]
pub struct Buffer {
pub data: *const u8,
pub len: usize,
}
// Helper for converting a PrimeField to little endian byte slice
fn slice_to_point<F: PrimeField>(point: &[u8; 32]) -> F {
let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works");
F::from_bigint(bigint).unwrap()
}
// Helper for converting a PrimeField to its U256 representation for Ethereum compatibility
fn point_to_slice<F: PrimeField>(point: F) -> [u8; 32] {
let point = point.into_bigint();
let point_bytes = point.to_bytes_le();
point_bytes.try_into().expect("always works")
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct G1 {
pub x: [u8; 32],
pub y: [u8; 32],
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct G2 {
pub x: [[u8; 32]; 2],
pub y: [[u8; 32]; 2],
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct Proof {
pub a: G1,
pub b: G2,
pub c: G1,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct VerifyingKey {
pub alpha1: G1,
pub beta2: G2,
pub gamma2: G2,
pub delta2: G2,
pub ic: *const G1,
pub ic_len: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct Inputs {
pub elms: *const [u8; 32],
pub len: usize,
}
impl From<&G1Affine> for G1 {
fn from(src: &G1Affine) -> Self {
Self {
x: point_to_slice(src.x),
y: point_to_slice(src.y),
}
}
}
impl From<&G2Affine> for G2 {
fn from(src: &G2Affine) -> Self {
// We should use the `.as_tuple()` method which handles converting
// the G2 elements to have the second limb first
Self {
x: [point_to_slice(src.x.c0), point_to_slice(src.x.c1)],
y: [point_to_slice(src.y.c0), point_to_slice(src.y.c1)],
}
}
}
impl From<&Groth16Proof<Bn254>> for Proof {
fn from(src: &Groth16Proof<Bn254>) -> Self {
Self {
a: (&src.a).into(),
b: (&src.b).into(),
c: (&src.c).into(),
}
}
}
impl From<G1> for G1Affine {
fn from(src: G1) -> Self {
let x: Fq = slice_to_point(&src.x);
let y: Fq = slice_to_point(&src.y);
if x.is_zero() && y.is_zero() {
G1Affine::identity()
} else {
G1Affine::new(x, y)
}
}
}
impl From<G2> for G2Affine {
fn from(src: G2) -> G2Affine {
let c0 = slice_to_point(&src.x[0]);
let c1 = slice_to_point(&src.x[1]);
let x = Fq2::new(c0, c1);
let c0 = slice_to_point(&src.y[0]);
let c1 = slice_to_point(&src.y[1]);
let y = Fq2::new(c0, c1);
if x.is_zero() && y.is_zero() {
G2Affine::identity()
} else {
G2Affine::new(x, y)
}
}
}
impl From<Proof> for ark_groth16::Proof<Bn254> {
fn from(src: Proof) -> ark_groth16::Proof<Bn254> {
ark_groth16::Proof {
a: src.a.into(),
b: src.b.into(),
c: src.c.into(),
}
}
}
impl From<VerifyingKey> for ark_groth16::VerifyingKey<Bn254> {
fn from(src: VerifyingKey) -> ark_groth16::VerifyingKey<Bn254> {
ark_groth16::VerifyingKey {
alpha_g1: src.alpha1.into(),
beta_g2: src.beta2.into(),
gamma_g2: src.gamma2.into(),
delta_g2: src.delta2.into(),
gamma_abc_g1: unsafe {
std::slice::from_raw_parts(src.ic, src.ic_len)
.iter()
.map(|p| (*p).into())
.collect()
},
}
}
}
impl From<&ark_groth16::VerifyingKey<Bn254>> for VerifyingKey {
fn from(vk: &ark_groth16::VerifyingKey<Bn254>) -> Self {
let ic: Vec<G1> = vk.gamma_abc_g1.iter().map(|p| p.into()).collect();
let len = ic.len();
Self {
alpha1: G1::from(&vk.alpha_g1),
beta2: G2::from(&vk.beta_g2),
gamma2: G2::from(&vk.gamma_g2),
delta2: G2::from(&vk.delta_g2),
ic: Box::leak(Box::new(ic)).as_slice().as_ptr(),
ic_len: len,
}
}
}
impl From<&[Fr]> for Inputs {
fn from(src: &[Fr]) -> Self {
let els: Vec<[u8; 32]> = src
.iter()
.map(|point| point.0.to_bytes_le().try_into().unwrap())
.collect();
let len = els.len();
Self {
elms: Box::leak(els.into_boxed_slice()).as_ptr(),
len: len,
}
}
}
impl From<Inputs> for Vec<Fr> {
fn from(src: Inputs) -> Self {
let els: Vec<Fr> = unsafe {
(&*slice_from_raw_parts(src.elms, src.len))
.iter()
.map(|point| {
let uint = BigUint::from_bytes_le(point);
Fr::from(uint)
})
.collect()
};
els
}
}
mod test {
use ark_std::UniformRand;
use super::*;
fn fq() -> Fq {
Fq::from(2)
}
fn fr() -> Fr {
Fr::from(2)
}
fn g1() -> G1Affine {
let rng = &mut ark_std::test_rng();
G1Affine::rand(rng)
}
fn g2() -> G2Affine {
let rng = &mut ark_std::test_rng();
G2Affine::rand(rng)
}
#[test]
fn convert_fq() {
let el = fq();
let el2 = point_to_slice(el);
let el3: Fq = slice_to_point(&el2);
let el4 = point_to_slice(el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
}
#[test]
fn convert_fr() {
let el = fr();
let el2 = point_to_slice(el);
let el3: Fr = slice_to_point(&el2);
let el4 = point_to_slice(el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
}
#[test]
fn convert_g1() {
let el = g1();
let el2 = G1::from(&el);
let el3: G1Affine = el2.into();
let el4 = G1::from(&el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
}
#[test]
fn convert_g2() {
let el = g2();
let el2 = G2::from(&el);
let el3: G2Affine = el2.into();
let el4 = G2::from(&el3);
assert_eq!(el, el3);
assert_eq!(el2, el4);
}
#[test]
fn convert_vk() {
let vk = ark_groth16::VerifyingKey::<Bn254> {
alpha_g1: g1(),
beta_g2: g2(),
gamma_g2: g2(),
delta_g2: g2(),
gamma_abc_g1: vec![g1(), g1(), g1()],
};
let vk_ffi = &VerifyingKey::from(&vk);
let ark_vk: ark_groth16::VerifyingKey<Bn254> = (*vk_ffi).into();
assert_eq!(ark_vk, vk);
}
#[test]
fn convert_proof() {
let p = ark_groth16::Proof::<Bn254> {
a: g1(),
b: g2(),
c: g1(),
};
let p2 = Proof::from(&p);
let p3 = ark_groth16::Proof::from(p2);
assert_eq!(p, p3);
}
}

View File

@ -1 +1,2 @@
pub mod ffi;
pub mod ffi_types;