diff --git a/src/circuit_tests/utils.rs b/src/circuit_tests/utils.rs index 4899b51..b989f4e 100644 --- a/src/circuit_tests/utils.rs +++ b/src/circuit_tests/utils.rs @@ -10,7 +10,7 @@ pub fn digest(input: &[U256], chunk_size: Option) -> U256 { let range = (i * chunk_size)..std::cmp::min((i + 1) * chunk_size, input.len()); let mut chunk = input[range].to_vec(); if chunk.len() < chunk_size { - chunk.resize(chunk_size as usize, uint!(0_U256)); + chunk.resize(chunk_size, uint!(0_U256)); } concat.push(hash(chunk.as_slice())); @@ -20,7 +20,7 @@ pub fn digest(input: &[U256], chunk_size: Option) -> U256 { return hash(concat.as_slice()); } - return concat[0]; + concat[0] } pub fn merkelize(leafs: &[U256]) -> U256 { @@ -43,5 +43,5 @@ pub fn merkelize(leafs: &[U256]) -> U256 { merkle = new_merkle; } - return merkle[0]; + merkle[0] } diff --git a/src/ffi.rs b/src/ffi.rs index dd5e3e5..c7aa242 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -32,13 +32,16 @@ impl ProofCtx { } } +/// # Safety +/// +/// Construct a StorageProofs object #[no_mangle] -pub extern "C" fn init( +pub unsafe extern "C" fn init( r1cs: *const &Buffer, wasm: *const &Buffer, zkey: *const &Buffer, ) -> *mut StorageProofs { - let r1cs = unsafe { + let r1cs = { if r1cs.is_null() { return std::ptr::null_mut(); } @@ -47,8 +50,8 @@ pub extern "C" fn init( str::from_utf8(slice).unwrap().to_string() }; - let wasm = unsafe { - if wasm == std::ptr::null() { + let wasm = { + if wasm.is_null() { return std::ptr::null_mut(); } @@ -56,7 +59,7 @@ pub extern "C" fn init( str::from_utf8(slice).unwrap().to_string() }; - let zkey = unsafe { + let zkey = { if !zkey.is_null() { let slice = std::slice::from_raw_parts((*zkey).data, (*zkey).len); Some(str::from_utf8(slice).unwrap().to_string()) @@ -68,8 +71,12 @@ pub extern "C" fn init( Box::into_raw(Box::new(StorageProofs::new(wasm, r1cs, zkey))) } + +/// # Safety +/// +/// Use after constructing a StorageProofs object #[no_mangle] -pub extern "C" fn prove( +pub unsafe extern "C" fn prove( prover_ptr: *mut StorageProofs, chunks: *const Buffer, siblings: *const Buffer, @@ -80,7 +87,7 @@ pub extern "C" fn prove( root: *const Buffer, salt: *const Buffer, ) -> *mut ProofCtx { - let chunks = unsafe { + let chunks = { let slice = std::slice::from_raw_parts((*chunks).data, (*chunks).len); slice .chunks(U256::BYTES) @@ -88,7 +95,7 @@ pub extern "C" fn prove( .collect::>() }; - let siblings = unsafe { + let siblings = { let slice = std::slice::from_raw_parts((*siblings).data, (*siblings).len); slice .chunks(U256::BYTES) @@ -96,7 +103,7 @@ pub extern "C" fn prove( .collect::>() }; - let hashes = unsafe { + let hashes = { let slice = std::slice::from_raw_parts((*hashes).data, (*hashes).len); slice .chunks(U256::BYTES) @@ -104,27 +111,24 @@ pub extern "C" fn prove( .collect::>() }; - let path = unsafe { + let path = { let slice = std::slice::from_raw_parts(path, path_len); slice.to_vec() }; - let pubkey = unsafe { - U256::try_from_le_slice(std::slice::from_raw_parts((*pubkey).data, (*pubkey).len)).unwrap() - }; + let pubkey = + U256::try_from_le_slice(std::slice::from_raw_parts((*pubkey).data, (*pubkey).len)).unwrap(); - let root = unsafe { - U256::try_from_le_slice(std::slice::from_raw_parts((*root).data, (*root).len)).unwrap() - }; + let root = + U256::try_from_le_slice(std::slice::from_raw_parts((*root).data, (*root).len)).unwrap(); - let salt = unsafe { - U256::try_from_le_slice(std::slice::from_raw_parts((*salt).data, (*salt).len)).unwrap() - }; + let salt = + U256::try_from_le_slice(std::slice::from_raw_parts((*salt).data, (*salt).len)).unwrap(); let proof_bytes = &mut Vec::new(); let public_inputs_bytes = &mut Vec::new(); - let mut _prover = unsafe { &mut *prover_ptr }; + let mut _prover = &mut *prover_ptr; _prover .prove( chunks.as_slice(), @@ -142,22 +146,25 @@ pub extern "C" fn prove( } #[no_mangle] -pub extern "C" fn verify( +/// # Safety +/// +/// Should be called on a valid proof and public inputs previously generated by prove +pub unsafe extern "C" fn verify( prover_ptr: *mut StorageProofs, proof: *const Buffer, public_inputs: *const Buffer, ) -> bool { - let proof = unsafe { std::slice::from_raw_parts((*proof).data, (*proof).len) }; - - let public_inputs = - unsafe { std::slice::from_raw_parts((*public_inputs).data, (*public_inputs).len) }; - - let mut _prover = unsafe { &mut *prover_ptr }; + let proof = std::slice::from_raw_parts((*proof).data, (*proof).len); + let public_inputs = std::slice::from_raw_parts((*public_inputs).data, (*public_inputs).len); + let mut _prover = &mut *prover_ptr; _prover.verify(proof, public_inputs).is_ok() } +/// # Safety +/// +/// Use on a valid pointer to StorageProofs or panics #[no_mangle] -pub extern "C" fn free_prover(prover: *mut StorageProofs) { +pub unsafe extern "C" fn free_prover(prover: *mut StorageProofs) { if prover.is_null() { return; } @@ -165,17 +172,21 @@ pub extern "C" fn free_prover(prover: *mut StorageProofs) { unsafe { drop(Box::from_raw(prover)) } } -pub extern "C" fn free_proof_ctx(ctx: *mut ProofCtx) { +/// # Safety +/// +/// Use on a valid pointer to ProofCtx or panics +#[no_mangle] +pub unsafe extern "C" fn free_proof_ctx(ctx: *mut ProofCtx) { if ctx.is_null() { return; } - unsafe { drop(Box::from_raw(ctx)) } + drop(Box::from_raw(ctx)) } #[cfg(test)] mod tests { - use ark_std::rand::{rngs::ThreadRng, distributions::Alphanumeric, Rng}; + use ark_std::rand::{distributions::Alphanumeric, rngs::ThreadRng, Rng}; use ruint::aliases::U256; use crate::{ diff --git a/src/poseidon/constants.rs b/src/poseidon/constants.rs index 4ea9465..813a336 100644 --- a/src/poseidon/constants.rs +++ b/src/poseidon/constants.rs @@ -34,12 +34,8 @@ pub static C_CONST: Lazy>> = Lazy::new(|| { }) .collect::, _>>() .unwrap() - .try_into() - .unwrap() }) .collect::>>() - .try_into() - .unwrap() }); pub static S_CONST: Lazy>> = Lazy::new(|| { @@ -62,12 +58,8 @@ pub static S_CONST: Lazy>> = Lazy::new(|| { }) .collect::, _>>() .unwrap() - .try_into() - .unwrap() }) .collect::>>() - .try_into() - .unwrap() }); pub static M_CONST: Lazy>>> = Lazy::new(|| { @@ -94,14 +86,10 @@ pub static M_CONST: Lazy>>> = Lazy::new(|| { }) .collect::, _>>() .unwrap() - .try_into() - .unwrap() }) .collect() }) .collect::>>>() - .try_into() - .unwrap() }); pub static P_CONST: Lazy>>> = Lazy::new(|| { @@ -128,12 +116,8 @@ pub static P_CONST: Lazy>>> = Lazy::new(|| { }) .collect::, _>>() .unwrap() - .try_into() - .unwrap() }) .collect() }) .collect::>>>() - .try_into() - .unwrap() }); diff --git a/src/poseidon/mod.rs b/src/poseidon/mod.rs index 1cb741b..90e027c 100644 --- a/src/poseidon/mod.rs +++ b/src/poseidon/mod.rs @@ -16,7 +16,7 @@ const N_ROUNDS_P: [i32; 16] = [ // Panics if `input` is not a valid field element. #[must_use] pub fn hash(inputs: &[U256]) -> U256 { - assert!(inputs.len() > 0); + assert!(!inputs.is_empty()); assert!(inputs.len() <= N_ROUNDS_P.len()); let t = inputs.len() + 1; @@ -35,7 +35,7 @@ pub fn hash(inputs: &[U256]) -> U256 { for r in 0..(n_rounds_f / 2 - 1) { state = state .iter() - .map(|a| a.pow(&[5])) + .map(|a| a.pow([5])) .enumerate() .map(|(i, a)| a + c[(r + 1) * t + i]) .collect(); @@ -57,7 +57,7 @@ pub fn hash(inputs: &[U256]) -> U256 { state = state .iter() - .map(|a| a.pow(&[5])) + .map(|a| a.pow([5])) .enumerate() .map(|(i, a)| a + c[(n_rounds_f / 2 - 1 + 1) * t + i]) .collect(); @@ -76,9 +76,9 @@ pub fn hash(inputs: &[U256]) -> U256 { }) .collect(); - for r in 0..n_rounds_p as usize { - state[0] = state[0].pow(&[5]); - state[0] = state[0] + c[(n_rounds_f / 2 + 1) * t + r]; + for r in 0..n_rounds_p { + state[0] = state[0].pow([5]); + state[0] += c[(n_rounds_f / 2 + 1) * t + r]; let s0 = state .iter() @@ -94,10 +94,10 @@ pub fn hash(inputs: &[U256]) -> U256 { state[0] = s0; } - for r in 0..(n_rounds_f / 2 - 1) as usize { + for r in 0..(n_rounds_f / 2 - 1) { state = state .iter() - .map(|a| a.pow(&[5])) + .map(|a| a.pow([5])) .enumerate() .map(|(i, a)| a + c[(n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + i]) .collect(); @@ -117,7 +117,7 @@ pub fn hash(inputs: &[U256]) -> U256 { .collect(); } - state = state.iter().map(|a| a.pow(&[5])).collect(); + state = state.iter().map(|a| a.pow([5])).collect(); state = state .iter() .enumerate() diff --git a/src/storage_proofs.rs b/src/storage_proofs.rs index fc10dae..c68f810 100644 --- a/src/storage_proofs.rs +++ b/src/storage_proofs.rs @@ -92,7 +92,7 @@ impl StorageProofs { let proof = Proof::::deserialize(proof_bytes).map_err(|e| e.to_string())?; let vk = prepare_verifying_key(&self.params.vk); - verify_proof(&vk, &proof, &inputs.as_slice()).map_err(|e| e.to_string())?; + verify_proof(&vk, &proof, inputs.as_slice()).map_err(|e| e.to_string())?; Ok(()) }