From d6bb7a6b309072b98ebc42fe8f07170477541f53 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Thu, 25 Jan 2024 13:05:38 -0600 Subject: [PATCH] rework ffi to export serialized types --- Cargo.toml | 2 + src/ffi.rs | 128 +++++-------------- src/ffi_types.rs | 310 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 4 files changed, 341 insertions(+), 100 deletions(-) create mode 100644 src/ffi_types.rs diff --git a/Cargo.toml b/Cargo.toml index 5b433d8..e570e6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,5 @@ ark-poly = { version = "=0.4.1", default-features = false, features = ["parallel ark-relations = { version = "=0.4.0", default-features = false } ark-serialize = { version = "=0.4.1", default-features = false } ruint = { version = "1.7.0", features = ["serde", "num-bigint", "ark-ff"] } +num-bigint = "0.4.3" + diff --git a/src/ffi.rs b/src/ffi.rs index afc02dc..816de13 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -9,38 +9,15 @@ use std::{ use ark_bn254::{Bn254, Fr}; use ark_circom::{read_zkey, CircomBuilder, CircomConfig}; use ark_crypto_primitives::snark::SNARK; -use ark_groth16::{prepare_verifying_key, Groth16, Proof, ProvingKey}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; +use ark_groth16::{prepare_verifying_key, Groth16, Proof as Groth16Proof, ProvingKey}; +use ark_serialize::Compress; use ark_std::rand::{rngs::ThreadRng, thread_rng}; - use ruint::aliases::U256; +use crate::ffi_types::*; + type GrothBn = Groth16; -pub const ERR_UNKNOWN: i32 = -1; -pub const ERR_OK: i32 = 0; -pub const ERR_WASM_PATH: i32 = 1; -pub const ERR_R1CS_PATH: i32 = 2; -pub const ERR_ZKEY_PATH: i32 = 3; -pub const ERR_INPUT_NAME: i32 = 4; -pub const ERR_INVALID_INPUT: i32 = 5; -pub const ERR_CANT_READ_ZKEY: i32 = 6; -pub const ERR_CIRCOM_BUILDER: i32 = 7; -pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8; -pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9; -pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10; -pub const ERR_GET_PUB_INPUTS: i32 = 11; -pub const ERR_MAKING_PROOF: i32 = 12; -pub const ERR_SERIALIZE_PROOF: i32 = 13; -pub const ERR_SERIALIZE_INPUTS: i32 = 14; - -#[derive(Debug, Clone)] -#[repr(C)] -pub struct Buffer { - data: *const u8, - len: usize, -} - #[derive(Debug, Clone)] // #[repr(C)] struct CircomBn254 { @@ -176,18 +153,13 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 { #[allow(private_interfaces)] pub unsafe extern "C" fn prove_circuit( ctx_ptr: *mut CircomCompatCtx, - compress: bool, - proof_bytes_ptr: &mut *mut Buffer, - inputs_bytes_ptr: &mut *mut Buffer, + proof_ptr: &mut *mut Proof, // inputs_bytes_ptr: &mut *mut Buffer, + inputs_ptr: &mut *mut Inputs, // inputs_bytes_ptr: &mut *mut Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { let circom = &mut *to_circom(ctx_ptr); let proving_key = &(*circom.proving_key); let rng = &mut (*ctx_ptr).rng; - let mode = match compress { - true => Compress::Yes, - false => Compress::No, - }; let circuit = (*circom.builder) .clone() @@ -199,40 +171,13 @@ pub unsafe extern "C" fn prove_circuit( .get_public_inputs() .ok_or_else(|| ERR_GET_PUB_INPUTS) .unwrap(); - let proof = GrothBn::prove(&proving_key, circuit, rng) + + let circomProof = GrothBn::prove(proving_key, circuit, rng) .map_err(|_| ERR_MAKING_PROOF) .unwrap(); - let mut proof_bytes = Vec::new(); - proof - .serialize_with_mode(&mut proof_bytes, mode) - .map_err(|_| ERR_SERIALIZE_PROOF) - .unwrap(); - - let mut public_inputs_bytes = Vec::new(); - inputs - .serialize_with_mode(&mut public_inputs_bytes, mode) - .map_err(|_| ERR_SERIALIZE_INPUTS) - .unwrap(); - - // leak the buffers to avoid rust from freeing the pointed to data, - // clone to avoid bytes from being freed - let proof_slice = Box::leak(Box::new(proof_bytes.clone())).as_slice(); - let proof_buff = Buffer { - data: proof_slice.as_ptr() as *const u8, - len: proof_bytes.len(), - }; - - // leak the buffers to avoid rust from freeing the pointed to data, - // clone to avoid bytes from being freed - let input_slice = Box::leak(Box::new(public_inputs_bytes.clone())).as_slice(); - let input_buff = Buffer { - data: input_slice.as_ptr() as *const u8, - len: public_inputs_bytes.len(), - }; - - *proof_bytes_ptr = Box::into_raw(Box::new(proof_buff)); - *inputs_bytes_ptr = Box::into_raw(Box::new(input_buff)); + *proof_ptr = Box::leak(Box::new((&circomProof).into())); + *inputs_ptr = Box::leak(Box::new(inputs.as_slice().into())); })); to_err_code(result) @@ -245,8 +190,8 @@ pub unsafe extern "C" fn prove_circuit( pub unsafe extern "C" fn verify_circuit( ctx_ptr: *mut CircomCompatCtx, compress: bool, - proof_bytes_ptr: *const Buffer, - inputs_bytes_ptr: *const Buffer, + proof: *const Proof, + inputs: *const Inputs, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { let mode = match compress { @@ -254,25 +199,12 @@ pub unsafe extern "C" fn verify_circuit( false => Compress::No, }; - let proof_bytes = - std::slice::from_raw_parts((*proof_bytes_ptr).data, (*proof_bytes_ptr).len); - - let proof = Proof::::deserialize_with_mode(proof_bytes, mode, Validate::Yes) - .map_err(|_| ERR_FAILED_TO_DESERIALIZE_PROOF) - .unwrap(); - - let public_inputs_bytes = - std::slice::from_raw_parts((*inputs_bytes_ptr).data, (*inputs_bytes_ptr).len); - let public_inputs: Vec = - CanonicalDeserialize::deserialize_with_mode(public_inputs_bytes, mode, Validate::Yes) - .map_err(|_| ERR_FAILED_TO_DESERIALIZE_INPUTS) - .unwrap(); - let circom = &mut *to_circom(ctx_ptr); let proving_key = &(*circom.proving_key); let pvk = prepare_verifying_key(&proving_key.vk); - GrothBn::verify_proof(&pvk, &proof, &public_inputs) + let inputs_vec: Vec = (*inputs).into(); + GrothBn::verify_proof(&pvk, &(*proof).into(), inputs_vec.as_slice()) .map_err(|_| ERR_FAILED_TO_VERIFY_PROOF) .unwrap(); })); @@ -347,7 +279,6 @@ build_fn!(push_input_u64, x: u64); #[cfg(test)] mod test { use std::ffi::CString; - use super::*; #[test] @@ -372,28 +303,25 @@ mod test { let b = CString::new("b".as_bytes()).unwrap(); push_input_i8(ctx_ptr, b.as_ptr(), 11); - let mut proof_bytes_ptr: *mut Buffer = std::ptr::null_mut(); - let mut inputs_bytes_ptr: *mut Buffer = std::ptr::null_mut(); + let mut proof_ptr: *mut Proof = std::ptr::null_mut(); + let mut inputs_ptr: *mut Inputs = std::ptr::null_mut(); - assert!(prove_circuit(ctx_ptr, true, &mut proof_bytes_ptr, &mut inputs_bytes_ptr) == ERR_OK); + assert!(prove_circuit(ctx_ptr, &mut proof_ptr, &mut inputs_ptr) == ERR_OK); - assert!(proof_bytes_ptr != std::ptr::null_mut()); - assert!((*proof_bytes_ptr).data != std::ptr::null()); - assert!((*proof_bytes_ptr).len > 0); + assert!(proof_ptr != std::ptr::null_mut()); + assert!(inputs_ptr != std::ptr::null_mut()); - assert!(inputs_bytes_ptr != std::ptr::null_mut()); - assert!((*inputs_bytes_ptr).data != std::ptr::null()); - assert!((*inputs_bytes_ptr).len > 0); + assert!( + verify_circuit(ctx_ptr, true, &(*proof_ptr), &(*inputs_ptr)) == ERR_OK + ); - assert!(verify_circuit(ctx_ptr, true, &(*proof_bytes_ptr), &(*inputs_bytes_ptr)) == ERR_OK); + // release_buffer(&mut proof_bytes_ptr); + // release_buffer(&mut inputs_bytes_ptr); + // release_circom_compat(&mut ctx_ptr); - release_buffer(&mut proof_bytes_ptr); - release_buffer(&mut inputs_bytes_ptr); - release_circom_compat(&mut ctx_ptr); - - assert!(ctx_ptr == std::ptr::null_mut()); - assert!(proof_bytes_ptr == std::ptr::null_mut()); - assert!(inputs_bytes_ptr == std::ptr::null_mut()); + // assert!(ctx_ptr == std::ptr::null_mut()); + // assert!(proof_bytes_ptr == std::ptr::null_mut()); + // assert!(inputs_bytes_ptr == std::ptr::null_mut()); }; } } diff --git a/src/ffi_types.rs b/src/ffi_types.rs new file mode 100644 index 0000000..787513c --- /dev/null +++ b/src/ffi_types.rs @@ -0,0 +1,310 @@ +use std::ptr::slice_from_raw_parts; + +use ark_bn254::{Bn254, Fq, Fq2, Fr, G1Affine, G2Affine}; +use ark_ff::{BigInteger, PrimeField}; +use ark_groth16::{Groth16, Proof as Groth16Proof}; +use ark_serialize::CanonicalDeserialize; +use ark_std::Zero; +use num_bigint::BigUint; + +type GrothBn = Groth16; + +pub const ERR_UNKNOWN: i32 = -1; +pub const ERR_OK: i32 = 0; +pub const ERR_WASM_PATH: i32 = 1; +pub const ERR_R1CS_PATH: i32 = 2; +pub const ERR_ZKEY_PATH: i32 = 3; +pub const ERR_INPUT_NAME: i32 = 4; +pub const ERR_INVALID_INPUT: i32 = 5; +pub const ERR_CANT_READ_ZKEY: i32 = 6; +pub const ERR_CIRCOM_BUILDER: i32 = 7; +pub const ERR_FAILED_TO_DESERIALIZE_PROOF: i32 = 8; +pub const ERR_FAILED_TO_DESERIALIZE_INPUTS: i32 = 9; +pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10; +pub const ERR_GET_PUB_INPUTS: i32 = 11; +pub const ERR_MAKING_PROOF: i32 = 12; +pub const ERR_SERIALIZE_PROOF: i32 = 13; +pub const ERR_SERIALIZE_INPUTS: i32 = 14; + +#[derive(Debug, Clone)] +#[repr(C)] +pub struct Buffer { + pub data: *const u8, + pub len: usize, +} + +// Helper for converting a PrimeField to little endian byte slice +fn slice_to_point(point: &[u8; 32]) -> F { + let bigint = F::BigInt::deserialize_uncompressed(&point[..]).expect("always works"); + F::from_bigint(bigint).unwrap() +} + +// Helper for converting a PrimeField to its U256 representation for Ethereum compatibility +fn point_to_slice(point: F) -> [u8; 32] { + let point = point.into_bigint(); + let point_bytes = point.to_bytes_le(); + point_bytes.try_into().expect("always works") +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] +pub struct G1 { + pub x: [u8; 32], + pub y: [u8; 32], +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] +pub struct G2 { + pub x: [[u8; 32]; 2], + pub y: [[u8; 32]; 2], +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] +pub struct Proof { + pub a: G1, + pub b: G2, + pub c: G1, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] +pub struct VerifyingKey { + pub alpha1: G1, + pub beta2: G2, + pub gamma2: G2, + pub delta2: G2, + pub ic: *const G1, + pub ic_len: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] +pub struct Inputs { + pub elms: *const [u8; 32], + pub len: usize, +} + +impl From<&G1Affine> for G1 { + fn from(src: &G1Affine) -> Self { + Self { + x: point_to_slice(src.x), + y: point_to_slice(src.y), + } + } +} + +impl From<&G2Affine> for G2 { + fn from(src: &G2Affine) -> Self { + // We should use the `.as_tuple()` method which handles converting + // the G2 elements to have the second limb first + Self { + x: [point_to_slice(src.x.c0), point_to_slice(src.x.c1)], + y: [point_to_slice(src.y.c0), point_to_slice(src.y.c1)], + } + } +} + +impl From<&Groth16Proof> for Proof { + fn from(src: &Groth16Proof) -> Self { + Self { + a: (&src.a).into(), + b: (&src.b).into(), + c: (&src.c).into(), + } + } +} + +impl From for G1Affine { + fn from(src: G1) -> Self { + let x: Fq = slice_to_point(&src.x); + let y: Fq = slice_to_point(&src.y); + if x.is_zero() && y.is_zero() { + G1Affine::identity() + } else { + G1Affine::new(x, y) + } + } +} + +impl From for G2Affine { + fn from(src: G2) -> G2Affine { + let c0 = slice_to_point(&src.x[0]); + let c1 = slice_to_point(&src.x[1]); + let x = Fq2::new(c0, c1); + + let c0 = slice_to_point(&src.y[0]); + let c1 = slice_to_point(&src.y[1]); + let y = Fq2::new(c0, c1); + + if x.is_zero() && y.is_zero() { + G2Affine::identity() + } else { + G2Affine::new(x, y) + } + } +} + +impl From for ark_groth16::Proof { + fn from(src: Proof) -> ark_groth16::Proof { + ark_groth16::Proof { + a: src.a.into(), + b: src.b.into(), + c: src.c.into(), + } + } +} + +impl From for ark_groth16::VerifyingKey { + fn from(src: VerifyingKey) -> ark_groth16::VerifyingKey { + ark_groth16::VerifyingKey { + alpha_g1: src.alpha1.into(), + beta_g2: src.beta2.into(), + gamma_g2: src.gamma2.into(), + delta_g2: src.delta2.into(), + gamma_abc_g1: unsafe { + std::slice::from_raw_parts(src.ic, src.ic_len) + .iter() + .map(|p| (*p).into()) + .collect() + }, + } + } +} + +impl From<&ark_groth16::VerifyingKey> for VerifyingKey { + fn from(vk: &ark_groth16::VerifyingKey) -> Self { + let ic: Vec = vk.gamma_abc_g1.iter().map(|p| p.into()).collect(); + let len = ic.len(); + Self { + alpha1: G1::from(&vk.alpha_g1), + beta2: G2::from(&vk.beta_g2), + gamma2: G2::from(&vk.gamma_g2), + delta2: G2::from(&vk.delta_g2), + ic: Box::leak(Box::new(ic)).as_slice().as_ptr(), + ic_len: len, + } + } +} + +impl From<&[Fr]> for Inputs { + fn from(src: &[Fr]) -> Self { + let els: Vec<[u8; 32]> = src + .iter() + .map(|point| point.0.to_bytes_le().try_into().unwrap()) + .collect(); + + let len = els.len(); + Self { + elms: Box::leak(els.into_boxed_slice()).as_ptr(), + len: len, + } + } +} + +impl From for Vec { + fn from(src: Inputs) -> Self { + let els: Vec = unsafe { + (&*slice_from_raw_parts(src.elms, src.len)) + .iter() + .map(|point| { + let uint = BigUint::from_bytes_le(point); + Fr::from(uint) + }) + .collect() + }; + + els + } +} + +mod test { + use ark_std::UniformRand; + + use super::*; + + fn fq() -> Fq { + Fq::from(2) + } + + fn fr() -> Fr { + Fr::from(2) + } + + fn g1() -> G1Affine { + let rng = &mut ark_std::test_rng(); + G1Affine::rand(rng) + } + + fn g2() -> G2Affine { + let rng = &mut ark_std::test_rng(); + G2Affine::rand(rng) + } + + #[test] + fn convert_fq() { + let el = fq(); + let el2 = point_to_slice(el); + let el3: Fq = slice_to_point(&el2); + let el4 = point_to_slice(el3); + assert_eq!(el, el3); + assert_eq!(el2, el4); + } + + #[test] + fn convert_fr() { + let el = fr(); + let el2 = point_to_slice(el); + let el3: Fr = slice_to_point(&el2); + let el4 = point_to_slice(el3); + assert_eq!(el, el3); + assert_eq!(el2, el4); + } + + #[test] + fn convert_g1() { + let el = g1(); + let el2 = G1::from(&el); + let el3: G1Affine = el2.into(); + let el4 = G1::from(&el3); + assert_eq!(el, el3); + assert_eq!(el2, el4); + } + + #[test] + fn convert_g2() { + let el = g2(); + let el2 = G2::from(&el); + let el3: G2Affine = el2.into(); + let el4 = G2::from(&el3); + assert_eq!(el, el3); + assert_eq!(el2, el4); + } + + #[test] + fn convert_vk() { + let vk = ark_groth16::VerifyingKey:: { + alpha_g1: g1(), + beta_g2: g2(), + gamma_g2: g2(), + delta_g2: g2(), + gamma_abc_g1: vec![g1(), g1(), g1()], + }; + let vk_ffi = &VerifyingKey::from(&vk); + let ark_vk: ark_groth16::VerifyingKey = (*vk_ffi).into(); + assert_eq!(ark_vk, vk); + } + + #[test] + fn convert_proof() { + let p = ark_groth16::Proof:: { + a: g1(), + b: g2(), + c: g1(), + }; + let p2 = Proof::from(&p); + let p3 = ark_groth16::Proof::from(p2); + assert_eq!(p, p3); + } +} diff --git a/src/lib.rs b/src/lib.rs index d8a989e..81badff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1 +1,2 @@ pub mod ffi; +pub mod ffi_types;