From aed402f6ee8a2dc225ec2db2eb36d1888aab7790 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Tue, 23 Jan 2024 14:25:55 -0600 Subject: [PATCH] add compress flag --- src/ffi.rs | 51 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index 680ee45..afc02dc 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -10,7 +10,7 @@ use ark_bn254::{Bn254, Fr}; use ark_circom::{read_zkey, CircomBuilder, CircomConfig}; use ark_crypto_primitives::snark::SNARK; use ark_groth16::{prepare_verifying_key, Groth16, Proof, ProvingKey}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; use ark_std::rand::{rngs::ThreadRng, thread_rng}; use ruint::aliases::U256; @@ -32,6 +32,7 @@ pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10; pub const ERR_GET_PUB_INPUTS: i32 = 11; pub const ERR_MAKING_PROOF: i32 = 12; pub const ERR_SERIALIZE_PROOF: i32 = 13; +pub const ERR_SERIALIZE_INPUTS: i32 = 14; #[derive(Debug, Clone)] #[repr(C)] @@ -175,14 +176,18 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 { #[allow(private_interfaces)] pub unsafe extern "C" fn prove_circuit( ctx_ptr: *mut CircomCompatCtx, + compress: bool, proof_bytes_ptr: &mut *mut Buffer, inputs_bytes_ptr: &mut *mut Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { let circom = &mut *to_circom(ctx_ptr); - let proving_key = &(*circom.proving_key); let rng = &mut (*ctx_ptr).rng; + let mode = match compress { + true => Compress::Yes, + false => Compress::No, + }; let circuit = (*circom.builder) .clone() @@ -199,12 +204,15 @@ pub unsafe extern "C" fn prove_circuit( .unwrap(); let mut proof_bytes = Vec::new(); - proof.serialize_compressed(&mut proof_bytes).unwrap(); + proof + .serialize_with_mode(&mut proof_bytes, mode) + .map_err(|_| ERR_SERIALIZE_PROOF) + .unwrap(); let mut public_inputs_bytes = Vec::new(); inputs - .serialize_compressed(&mut public_inputs_bytes) - .map_err(|_| ERR_SERIALIZE_PROOF) + .serialize_with_mode(&mut public_inputs_bytes, mode) + .map_err(|_| ERR_SERIALIZE_INPUTS) .unwrap(); // leak the buffers to avoid rust from freeing the pointed to data, @@ -236,26 +244,31 @@ pub unsafe extern "C" fn prove_circuit( #[allow(private_interfaces)] pub unsafe extern "C" fn verify_circuit( ctx_ptr: *mut CircomCompatCtx, + compress: bool, proof_bytes_ptr: *const Buffer, inputs_bytes_ptr: *const Buffer, ) -> i32 { let result = catch_unwind(AssertUnwindSafe(|| { + let mode = match compress { + true => Compress::Yes, + false => Compress::No, + }; + let proof_bytes = std::slice::from_raw_parts((*proof_bytes_ptr).data, (*proof_bytes_ptr).len); - let proof = Proof::::deserialize_compressed(proof_bytes) + let proof = Proof::::deserialize_with_mode(proof_bytes, mode, Validate::Yes) .map_err(|_| ERR_FAILED_TO_DESERIALIZE_PROOF) .unwrap(); let public_inputs_bytes = std::slice::from_raw_parts((*inputs_bytes_ptr).data, (*inputs_bytes_ptr).len); let public_inputs: Vec = - CanonicalDeserialize::deserialize_compressed(public_inputs_bytes) + CanonicalDeserialize::deserialize_with_mode(public_inputs_bytes, mode, Validate::Yes) .map_err(|_| ERR_FAILED_TO_DESERIALIZE_INPUTS) .unwrap(); let circom = &mut *to_circom(ctx_ptr); - let proving_key = &(*circom.proving_key); let pvk = prepare_verifying_key(&proving_key.vk); @@ -323,13 +336,13 @@ macro_rules! build_fn }; } -build_fn!(push_input_numeric_i8, x: i8); -build_fn!(push_input_numeric_u8, x: u8); -build_fn!(push_input_numeric_i16, x: i16); -build_fn!(push_input_numeric_u16, x: u16); -build_fn!(push_input_numeric_i32, x: i32); -build_fn!(push_input_numeric_u32, x: u32); -build_fn!(push_input_numeric_u64, x: u64); +build_fn!(push_input_i8, x: i8); +build_fn!(push_input_u8, x: u8); +build_fn!(push_input_i16, x: i16); +build_fn!(push_input_u16, x: u16); +build_fn!(push_input_i32, x: i32); +build_fn!(push_input_u32, x: u32); +build_fn!(push_input_u64, x: u64); #[cfg(test)] mod test { @@ -354,15 +367,15 @@ mod test { assert!(ctx_ptr != std::ptr::null_mut()); let a = CString::new("a".as_bytes()).unwrap(); - push_input_numeric_i8(ctx_ptr, a.as_ptr(), 3); + push_input_i8(ctx_ptr, a.as_ptr(), 3); let b = CString::new("b".as_bytes()).unwrap(); - push_input_numeric_i8(ctx_ptr, b.as_ptr(), 11); + push_input_i8(ctx_ptr, b.as_ptr(), 11); let mut proof_bytes_ptr: *mut Buffer = std::ptr::null_mut(); let mut inputs_bytes_ptr: *mut Buffer = std::ptr::null_mut(); - assert!(prove_circuit(ctx_ptr, &mut proof_bytes_ptr, &mut inputs_bytes_ptr) == ERR_OK); + assert!(prove_circuit(ctx_ptr, true, &mut proof_bytes_ptr, &mut inputs_bytes_ptr) == ERR_OK); assert!(proof_bytes_ptr != std::ptr::null_mut()); assert!((*proof_bytes_ptr).data != std::ptr::null()); @@ -372,7 +385,7 @@ mod test { assert!((*inputs_bytes_ptr).data != std::ptr::null()); assert!((*inputs_bytes_ptr).len > 0); - assert!(verify_circuit(ctx_ptr, &(*proof_bytes_ptr), &(*inputs_bytes_ptr)) == ERR_OK); + assert!(verify_circuit(ctx_ptr, true, &(*proof_bytes_ptr), &(*inputs_bytes_ptr)) == ERR_OK); release_buffer(&mut proof_bytes_ptr); release_buffer(&mut inputs_bytes_ptr);