diff --git a/src/ffi.rs b/src/ffi.rs index 007a65f..7f3f4ef 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -1,17 +1,14 @@ -use crate::public::RLN; +use crate::{circuit::rln, public::RLN}; use bellman::pairing::bn256::Bn256; use std::slice; /// Buffer struct is taken from /// https://github.com/celo-org/celo-threshold-bls-rs/blob/master/crates/threshold-bls-ffi/src/ffi.rs -/// Data structure which is used to store buffers of varying length #[repr(C)] #[derive(Clone, Debug, PartialEq)] pub struct Buffer { - /// Pointer to the message pub ptr: *const u8, - /// The length of the buffer pub len: usize, } @@ -102,6 +99,19 @@ pub unsafe extern "C" fn hash( true } +#[no_mangle] +pub unsafe extern "C" fn key_gen(ctx: *const RLN, keypair_buffer: *mut Buffer) -> bool { + let rln = unsafe { &*ctx }; + let mut output_data: Vec = Vec::new(); + match rln.key_gen(&mut output_data) { + Ok(_) => (), + Err(_) => return false, + } + unsafe { *keypair_buffer = Buffer::from(&output_data[..]) }; + std::mem::forget(output_data); + true +} + use sapling_crypto::bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr}; use sapling_crypto::bellman::pairing::Engine; use std::io::{self, Read, Write}; @@ -227,4 +237,35 @@ mod tests { let result_data = <&[u8]>::from(&result_buffer); assert_eq!(expected_data.as_slice(), result_data); } + + #[test] + fn test_keygen_ffi() { + let rln_test = rln_test(); + + let mut circuit_parameters: Vec = Vec::new(); + rln_test + .export_circuit_parameters(&mut circuit_parameters) + .unwrap(); + let mut hasher = rln_test.hasher(); + + let rln_pointer = rln_pointer(circuit_parameters); + let rln_pointer = unsafe { &*rln_pointer.assume_init() }; + + let mut keypair_buffer = MaybeUninit::::uninit(); + + let success = unsafe { key_gen(rln_pointer, keypair_buffer.as_mut_ptr()) }; + assert!(success, "proof generation failed"); + + let keypair_buffer = unsafe { keypair_buffer.assume_init() }; + let mut keypair_data = <&[u8]>::from(&keypair_buffer); + + let mut buf = ::Repr::default(); + buf.read_le(&mut keypair_data).unwrap(); + let secret = Fr::from_repr(buf).unwrap(); + buf.read_le(&mut keypair_data).unwrap(); + let public = Fr::from_repr(buf).unwrap(); + let expected_public: Fr = hasher.hash(vec![secret]); + + assert_eq!(public, expected_public); + } } diff --git a/src/public.rs b/src/public.rs index 050d9dd..5474562 100644 --- a/src/public.rs +++ b/src/public.rs @@ -125,6 +125,16 @@ where Ok(success) } + pub fn key_gen(&self, mut w: W) -> io::Result<()> { + let mut rng = XorShiftRng::from_seed([0x3dbe6258, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let mut hasher = self.hasher(); + let secret = E::Fr::rand(&mut rng); + let public: E::Fr = hasher.hash(vec![secret.clone()]); + secret.into_repr().write_le(&mut w)?; + public.into_repr().write_le(&mut w)?; + Ok(()) + } + pub fn export_verifier_key(&self, w: W) -> io::Result<()> { self.circuit_parameters.vk.write(w) }