diff --git a/README.md b/README.md index cc07a2e..bc08fff 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,16 @@ let zkey_path = "./test-vectors/multiplier2_final.zkey"; let proof = rust_rapidsnark::groth16_prover_zkey_file_wrapper(zkey_path, wtns_buffer).unwrap(); ``` +You can also prove from an in-memory zkey buffer: + +```rust +let zkey_buffer = std::fs::read("./test-vectors/multiplier2_final.zkey")?; +let proof = rust_rapidsnark::groth16_prover_zkey_buffer_wrapper(&zkey_buffer, &wtns_buffer)?; +``` + ### Verify the proof -Verify the proof by using the `groth16_verifier_zkey_file_wrapper` function. +Verify the proof by using the `groth16_verify_wrapper` function. ```rust let vkey = std::fs::read_to_string("./test-vectors/keccak256_256_test.vkey.json")?; diff --git a/crates/README.md b/crates/README.md index cc07a2e..bc08fff 100644 --- a/crates/README.md +++ b/crates/README.md @@ -49,9 +49,16 @@ let zkey_path = "./test-vectors/multiplier2_final.zkey"; let proof = rust_rapidsnark::groth16_prover_zkey_file_wrapper(zkey_path, wtns_buffer).unwrap(); ``` +You can also prove from an in-memory zkey buffer: + +```rust +let zkey_buffer = std::fs::read("./test-vectors/multiplier2_final.zkey")?; +let proof = rust_rapidsnark::groth16_prover_zkey_buffer_wrapper(&zkey_buffer, &wtns_buffer)?; +``` + ### Verify the proof -Verify the proof by using the `groth16_verifier_zkey_file_wrapper` function. +Verify the proof by using the `groth16_verify_wrapper` function. ```rust let vkey = std::fs::read_to_string("./test-vectors/keccak256_256_test.vkey.json")?; diff --git a/crates/src/lib.rs b/crates/src/lib.rs index efed7b6..7c497b2 100644 --- a/crates/src/lib.rs +++ b/crates/src/lib.rs @@ -17,6 +17,7 @@ //! use std::collections::HashMap; +use std::ffi::{c_char, c_ulonglong, c_void}; use std::str::FromStr; use anyhow::Result; @@ -33,17 +34,43 @@ pub struct ProofResult { pub public_signals: String, } +const PROVER_OK: i32 = 0; +const PROVER_ERROR_SHORT_BUFFER: i32 = 2; + extern "C" { + pub fn groth16_public_size_for_zkey_buf( + zkey_buffer: *const c_void, + zkey_size: c_ulonglong, + public_size: *mut c_ulonglong, + error_msg: *mut c_char, + error_msg_maxsize: c_ulonglong, + ) -> i32; + + pub fn groth16_proof_size(proof_size: *mut c_ulonglong); + + pub fn groth16_prover( + zkey_buffer: *const c_void, + zkey_size: c_ulonglong, + wtns_buffer: *const c_void, + wtns_size: c_ulonglong, + proof_buffer: *mut c_char, + proof_size: *mut c_ulonglong, + public_buffer: *mut c_char, + public_size: *mut c_ulonglong, + error_msg: *mut c_char, + error_msg_maxsize: c_ulonglong, + ) -> i32; + pub fn groth16_prover_zkey_file( zkey_file_path: *const std::os::raw::c_char, wtns_buffer: *const std::os::raw::c_void, - wtns_size: std::ffi::c_ulong, + wtns_size: c_ulonglong, proof_buffer: *mut std::os::raw::c_char, - proof_size: *mut std::ffi::c_ulong, + proof_size: *mut c_ulonglong, public_buffer: *mut std::os::raw::c_char, - public_size: *mut std::ffi::c_ulong, + public_size: *mut c_ulonglong, error_msg: *mut std::os::raw::c_char, - error_msg_maxsize: std::ffi::c_ulong, + error_msg_maxsize: c_ulonglong, ) -> i32; pub fn groth16_verify( @@ -110,6 +137,37 @@ pub fn parse_bigints_to_witness(bigints: Vec) -> io::Result> { Ok(buffer) } +fn error_message(error_msg: &[u8]) -> String { + let end = error_msg + .iter() + .position(|byte| *byte == 0) + .unwrap_or(error_msg.len()); + String::from_utf8_lossy(&error_msg[..end]).into_owned() +} + +fn buffer_len(size: c_ulonglong, name: &str) -> Result { + usize::try_from(size).map_err(|_| anyhow::anyhow!("{name} buffer size is too large")) +} + +fn buffer_len_with_null(size: c_ulonglong, name: &str) -> Result { + let size = size + .checked_add(1) + .ok_or_else(|| anyhow::anyhow!("{name} buffer size overflowed"))?; + buffer_len(size, name) +} + +fn output_string(name: &str, buffer: &[u8], size: c_ulonglong) -> Result { + let size = buffer_len(size, name)?; + if size > buffer.len() { + return Err(anyhow::anyhow!( + "{name} output size {size} exceeds buffer size {}", + buffer.len() + )); + } + String::from_utf8(buffer[..size].to_vec()) + .map_err(|err| anyhow::anyhow!("{name} output is not valid UTF-8: {err}")) +} + /// Wrapper for `groth16_prover_zkey_file` pub fn groth16_prover_zkey_file_wrapper( zkey_path: &str, @@ -161,6 +219,105 @@ pub fn groth16_prover_zkey_file_wrapper( } } +/// Wrapper for `groth16_prover`, which proves from an in-memory zkey buffer. +pub fn groth16_prover_zkey_buffer_wrapper( + zkey_buffer: &[u8], + wtns_buffer: &[u8], +) -> Result { + let zkey_size = zkey_buffer.len() as c_ulonglong; + let wtns_size = wtns_buffer.len() as c_ulonglong; + + let mut error_msg = vec![0u8; 256]; + let error_msg_ptr = error_msg.as_mut_ptr() as *mut c_char; + + let mut proof_size: c_ulonglong = 0; + let mut public_size: c_ulonglong = 0; + + unsafe { + groth16_proof_size(&mut proof_size); + + let result = groth16_public_size_for_zkey_buf( + zkey_buffer.as_ptr() as *const c_void, + zkey_size, + &mut public_size, + error_msg_ptr, + error_msg.len() as c_ulonglong, + ); + if result != PROVER_OK { + return Err(anyhow::anyhow!( + "Public signal buffer size calculation failed: {}", + error_message(&error_msg) + )); + } + } + + let mut proof_buffer = vec![0u8; buffer_len_with_null(proof_size, "proof")?]; + let mut public_buffer = vec![0u8; buffer_len_with_null(public_size, "public signals")?]; + + let mut proof_output_size = proof_buffer.len() as c_ulonglong; + let mut public_output_size = public_buffer.len() as c_ulonglong; + + let prove_once = |proof_buffer: &mut [u8], + proof_output_size: &mut c_ulonglong, + public_buffer: &mut [u8], + public_output_size: &mut c_ulonglong, + error_msg: &mut [u8]| { + error_msg.fill(0); + unsafe { + groth16_prover( + zkey_buffer.as_ptr() as *const c_void, + zkey_size, + wtns_buffer.as_ptr() as *const c_void, + wtns_size, + proof_buffer.as_mut_ptr() as *mut c_char, + proof_output_size, + public_buffer.as_mut_ptr() as *mut c_char, + public_output_size, + error_msg.as_mut_ptr() as *mut c_char, + error_msg.len() as c_ulonglong, + ) + } + }; + + let mut result = prove_once( + &mut proof_buffer, + &mut proof_output_size, + &mut public_buffer, + &mut public_output_size, + &mut error_msg, + ); + + if result == PROVER_ERROR_SHORT_BUFFER { + proof_buffer.resize(buffer_len_with_null(proof_output_size, "proof")?, 0); + public_buffer.resize( + buffer_len_with_null(public_output_size, "public signals")?, + 0, + ); + proof_output_size = proof_buffer.len() as c_ulonglong; + public_output_size = public_buffer.len() as c_ulonglong; + + result = prove_once( + &mut proof_buffer, + &mut proof_output_size, + &mut public_buffer, + &mut public_output_size, + &mut error_msg, + ); + } + + if result != PROVER_OK { + return Err(anyhow::anyhow!( + "Proof generation failed: {}", + error_message(&error_msg) + )); + } + + Ok(ProofResult { + proof: output_string("proof", &proof_buffer, proof_output_size)?, + public_signals: output_string("public signals", &public_buffer, public_output_size)?, + }) +} + /// Wrapper for `groth16_verify` pub fn groth16_verify_wrapper(proof: &str, inputs: &str, verification_key: &str) -> Result { let proof_cstr = std::ffi::CString::new(proof).unwrap(); diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 7e8def6..4c01042 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -84,6 +84,42 @@ mod tests { Ok(()) } + #[test] + fn test_prove_rapidsnark_zkey_buffer() -> Result<()> { + let zkey_path = "./test-vectors/multiplier2_final.zkey"; + let zkey_buffer = std::fs::read(zkey_path)?; + + let mut inputs = HashMap::new(); + inputs.insert("a".to_string(), vec!["3".to_string()]); + inputs.insert("b".to_string(), vec!["11".to_string()]); + + let wtns_buffer = compute_witness(inputs, multiplier2_witness)?; + let path_proof_result = + rust_rapidsnark::groth16_prover_zkey_file_wrapper(zkey_path, wtns_buffer.clone())?; + let buffer_proof_result = + rust_rapidsnark::groth16_prover_zkey_buffer_wrapper(&zkey_buffer, &wtns_buffer)?; + + assert_eq!( + path_proof_result.public_signals, + buffer_proof_result.public_signals + ); + + let vkey = std::fs::read_to_string("./test-vectors/multiplier2.vkey.json")?; + let path_valid = rust_rapidsnark::groth16_verify_wrapper( + &path_proof_result.proof, + &path_proof_result.public_signals, + &vkey, + )?; + let buffer_valid = rust_rapidsnark::groth16_verify_wrapper( + &buffer_proof_result.proof, + &buffer_proof_result.public_signals, + &vkey, + )?; + assert!(path_valid); + assert!(buffer_valid); + Ok(()) + } + #[test] fn test_prove_rapidsnark_keccak() -> Result<()> { // Create a new MoproCircom instance