add compress flag

This commit is contained in:
Dmitriy Ryajov 2024-01-23 14:25:55 -06:00
parent a20e7cba06
commit aed402f6ee
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
1 changed files with 32 additions and 19 deletions

View File

@ -10,7 +10,7 @@ use ark_bn254::{Bn254, Fr};
use ark_circom::{read_zkey, CircomBuilder, CircomConfig}; use ark_circom::{read_zkey, CircomBuilder, CircomConfig};
use ark_crypto_primitives::snark::SNARK; use ark_crypto_primitives::snark::SNARK;
use ark_groth16::{prepare_verifying_key, Groth16, Proof, ProvingKey}; use ark_groth16::{prepare_verifying_key, Groth16, Proof, ProvingKey};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
use ark_std::rand::{rngs::ThreadRng, thread_rng}; use ark_std::rand::{rngs::ThreadRng, thread_rng};
use ruint::aliases::U256; use ruint::aliases::U256;
@ -32,6 +32,7 @@ pub const ERR_FAILED_TO_VERIFY_PROOF: i32 = 10;
pub const ERR_GET_PUB_INPUTS: i32 = 11; pub const ERR_GET_PUB_INPUTS: i32 = 11;
pub const ERR_MAKING_PROOF: i32 = 12; pub const ERR_MAKING_PROOF: i32 = 12;
pub const ERR_SERIALIZE_PROOF: i32 = 13; pub const ERR_SERIALIZE_PROOF: i32 = 13;
pub const ERR_SERIALIZE_INPUTS: i32 = 14;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[repr(C)] #[repr(C)]
@ -175,14 +176,18 @@ unsafe fn to_circom(ctx_ptr: *mut CircomCompatCtx) -> *mut CircomBn254 {
#[allow(private_interfaces)] #[allow(private_interfaces)]
pub unsafe extern "C" fn prove_circuit( pub unsafe extern "C" fn prove_circuit(
ctx_ptr: *mut CircomCompatCtx, ctx_ptr: *mut CircomCompatCtx,
compress: bool,
proof_bytes_ptr: &mut *mut Buffer, proof_bytes_ptr: &mut *mut Buffer,
inputs_bytes_ptr: &mut *mut Buffer, inputs_bytes_ptr: &mut *mut Buffer,
) -> i32 { ) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| { let result = catch_unwind(AssertUnwindSafe(|| {
let circom = &mut *to_circom(ctx_ptr); let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key); let proving_key = &(*circom.proving_key);
let rng = &mut (*ctx_ptr).rng; let rng = &mut (*ctx_ptr).rng;
let mode = match compress {
true => Compress::Yes,
false => Compress::No,
};
let circuit = (*circom.builder) let circuit = (*circom.builder)
.clone() .clone()
@ -199,12 +204,15 @@ pub unsafe extern "C" fn prove_circuit(
.unwrap(); .unwrap();
let mut proof_bytes = Vec::new(); let mut proof_bytes = Vec::new();
proof.serialize_compressed(&mut proof_bytes).unwrap(); proof
.serialize_with_mode(&mut proof_bytes, mode)
.map_err(|_| ERR_SERIALIZE_PROOF)
.unwrap();
let mut public_inputs_bytes = Vec::new(); let mut public_inputs_bytes = Vec::new();
inputs inputs
.serialize_compressed(&mut public_inputs_bytes) .serialize_with_mode(&mut public_inputs_bytes, mode)
.map_err(|_| ERR_SERIALIZE_PROOF) .map_err(|_| ERR_SERIALIZE_INPUTS)
.unwrap(); .unwrap();
// leak the buffers to avoid rust from freeing the pointed to data, // leak the buffers to avoid rust from freeing the pointed to data,
@ -236,26 +244,31 @@ pub unsafe extern "C" fn prove_circuit(
#[allow(private_interfaces)] #[allow(private_interfaces)]
pub unsafe extern "C" fn verify_circuit( pub unsafe extern "C" fn verify_circuit(
ctx_ptr: *mut CircomCompatCtx, ctx_ptr: *mut CircomCompatCtx,
compress: bool,
proof_bytes_ptr: *const Buffer, proof_bytes_ptr: *const Buffer,
inputs_bytes_ptr: *const Buffer, inputs_bytes_ptr: *const Buffer,
) -> i32 { ) -> i32 {
let result = catch_unwind(AssertUnwindSafe(|| { let result = catch_unwind(AssertUnwindSafe(|| {
let mode = match compress {
true => Compress::Yes,
false => Compress::No,
};
let proof_bytes = let proof_bytes =
std::slice::from_raw_parts((*proof_bytes_ptr).data, (*proof_bytes_ptr).len); std::slice::from_raw_parts((*proof_bytes_ptr).data, (*proof_bytes_ptr).len);
let proof = Proof::<Bn254>::deserialize_compressed(proof_bytes) let proof = Proof::<Bn254>::deserialize_with_mode(proof_bytes, mode, Validate::Yes)
.map_err(|_| ERR_FAILED_TO_DESERIALIZE_PROOF) .map_err(|_| ERR_FAILED_TO_DESERIALIZE_PROOF)
.unwrap(); .unwrap();
let public_inputs_bytes = let public_inputs_bytes =
std::slice::from_raw_parts((*inputs_bytes_ptr).data, (*inputs_bytes_ptr).len); std::slice::from_raw_parts((*inputs_bytes_ptr).data, (*inputs_bytes_ptr).len);
let public_inputs: Vec<Fr> = let public_inputs: Vec<Fr> =
CanonicalDeserialize::deserialize_compressed(public_inputs_bytes) CanonicalDeserialize::deserialize_with_mode(public_inputs_bytes, mode, Validate::Yes)
.map_err(|_| ERR_FAILED_TO_DESERIALIZE_INPUTS) .map_err(|_| ERR_FAILED_TO_DESERIALIZE_INPUTS)
.unwrap(); .unwrap();
let circom = &mut *to_circom(ctx_ptr); let circom = &mut *to_circom(ctx_ptr);
let proving_key = &(*circom.proving_key); let proving_key = &(*circom.proving_key);
let pvk = prepare_verifying_key(&proving_key.vk); let pvk = prepare_verifying_key(&proving_key.vk);
@ -323,13 +336,13 @@ macro_rules! build_fn
}; };
} }
build_fn!(push_input_numeric_i8, x: i8); build_fn!(push_input_i8, x: i8);
build_fn!(push_input_numeric_u8, x: u8); build_fn!(push_input_u8, x: u8);
build_fn!(push_input_numeric_i16, x: i16); build_fn!(push_input_i16, x: i16);
build_fn!(push_input_numeric_u16, x: u16); build_fn!(push_input_u16, x: u16);
build_fn!(push_input_numeric_i32, x: i32); build_fn!(push_input_i32, x: i32);
build_fn!(push_input_numeric_u32, x: u32); build_fn!(push_input_u32, x: u32);
build_fn!(push_input_numeric_u64, x: u64); build_fn!(push_input_u64, x: u64);
#[cfg(test)] #[cfg(test)]
mod test { mod test {
@ -354,15 +367,15 @@ mod test {
assert!(ctx_ptr != std::ptr::null_mut()); assert!(ctx_ptr != std::ptr::null_mut());
let a = CString::new("a".as_bytes()).unwrap(); let a = CString::new("a".as_bytes()).unwrap();
push_input_numeric_i8(ctx_ptr, a.as_ptr(), 3); push_input_i8(ctx_ptr, a.as_ptr(), 3);
let b = CString::new("b".as_bytes()).unwrap(); let b = CString::new("b".as_bytes()).unwrap();
push_input_numeric_i8(ctx_ptr, b.as_ptr(), 11); push_input_i8(ctx_ptr, b.as_ptr(), 11);
let mut proof_bytes_ptr: *mut Buffer = std::ptr::null_mut(); let mut proof_bytes_ptr: *mut Buffer = std::ptr::null_mut();
let mut inputs_bytes_ptr: *mut Buffer = std::ptr::null_mut(); let mut inputs_bytes_ptr: *mut Buffer = std::ptr::null_mut();
assert!(prove_circuit(ctx_ptr, &mut proof_bytes_ptr, &mut inputs_bytes_ptr) == ERR_OK); assert!(prove_circuit(ctx_ptr, true, &mut proof_bytes_ptr, &mut inputs_bytes_ptr) == ERR_OK);
assert!(proof_bytes_ptr != std::ptr::null_mut()); assert!(proof_bytes_ptr != std::ptr::null_mut());
assert!((*proof_bytes_ptr).data != std::ptr::null()); assert!((*proof_bytes_ptr).data != std::ptr::null());
@ -372,7 +385,7 @@ mod test {
assert!((*inputs_bytes_ptr).data != std::ptr::null()); assert!((*inputs_bytes_ptr).data != std::ptr::null());
assert!((*inputs_bytes_ptr).len > 0); assert!((*inputs_bytes_ptr).len > 0);
assert!(verify_circuit(ctx_ptr, &(*proof_bytes_ptr), &(*inputs_bytes_ptr)) == ERR_OK); assert!(verify_circuit(ctx_ptr, true, &(*proof_bytes_ptr), &(*inputs_bytes_ptr)) == ERR_OK);
release_buffer(&mut proof_bytes_ptr); release_buffer(&mut proof_bytes_ptr);
release_buffer(&mut inputs_bytes_ptr); release_buffer(&mut inputs_bytes_ptr);