Merge pull request #357 from logos-blockchain/marvin/issue_295

[Issue 295] ensure! in public transaction
This commit is contained in:
jonesmarvin8 2026-03-16 09:49:52 -04:00 committed by GitHub
commit d665540495
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 37 deletions

View File

@ -2,6 +2,15 @@ use std::io;
use thiserror::Error;
#[macro_export]
macro_rules! ensure {
($cond:expr, $err:expr) => {
if !$cond {
return Err($err);
}
};
}
#[derive(Error, Debug)]
pub enum NssaError {
#[error("Invalid input: {0}")]
@ -58,3 +67,24 @@ pub enum NssaError {
#[error("Chain of calls is too long")]
MaxChainedCallsDepthExceeded,
}
#[cfg(test)]
mod tests {
#[derive(Debug)]
enum TestError {
TestErr,
}
fn test_function_ensure(cond: bool) -> Result<(), TestError> {
ensure!(cond, TestError::TestErr);
Ok(())
}
#[test]
fn test_ensure() {
assert!(test_function_ensure(true).is_ok());
assert!(test_function_ensure(false).is_err());
}
}

View File

@ -9,7 +9,7 @@ use nssa_core::{
use sha2::{Digest, digest::FixedOutput};
use crate::{
V02State,
V02State, ensure,
error::NssaError,
public_transaction::{Message, WitnessSet},
state::MAX_NUMBER_CHAINED_CALLS,
@ -70,33 +70,33 @@ impl PublicTransaction {
let witness_set = self.witness_set();
// All account_ids must be different
if message.account_ids.iter().collect::<HashSet<_>>().len() != message.account_ids.len() {
return Err(NssaError::InvalidInput(
"Duplicate account_ids found in message".into(),
));
}
ensure!(
message.account_ids.iter().collect::<HashSet<_>>().len() == message.account_ids.len(),
NssaError::InvalidInput("Duplicate account_ids found in message".into(),)
);
// Check exactly one nonce is provided for each signature
if message.nonces.len() != witness_set.signatures_and_public_keys.len() {
return Err(NssaError::InvalidInput(
ensure!(
message.nonces.len() == witness_set.signatures_and_public_keys.len(),
NssaError::InvalidInput(
"Mismatch between number of nonces and signatures/public keys".into(),
));
}
)
);
// Check the signatures are valid
if !witness_set.is_valid_for(message) {
return Err(NssaError::InvalidInput(
"Invalid signature for given message and public key".into(),
));
}
ensure!(
witness_set.is_valid_for(message),
NssaError::InvalidInput("Invalid signature for given message and public key".into())
);
let signer_account_ids = self.signer_account_ids();
// Check nonces corresponds to the current nonces on the public state.
for (account_id, nonce) in signer_account_ids.iter().zip(&message.nonces) {
let current_nonce = state.get_account_by_id(*account_id).nonce;
if current_nonce != *nonce {
return Err(NssaError::InvalidInput("Nonce mismatch".into()));
}
ensure!(
current_nonce == *nonce,
NssaError::InvalidInput("Nonce mismatch".into())
);
}
// Build pre_states for execution
@ -125,9 +125,10 @@ impl PublicTransaction {
let mut chain_calls_counter = 0;
while let Some((chained_call, caller_program_id)) = chained_calls.pop_front() {
if chain_calls_counter > MAX_NUMBER_CHAINED_CALLS {
return Err(NssaError::MaxChainedCallsDepthExceeded);
}
ensure!(
chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS,
NssaError::MaxChainedCallsDepthExceeded
);
// Check that the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&chained_call.program_id) else {
@ -158,28 +159,31 @@ impl PublicTransaction {
.get(&account_id)
.cloned()
.unwrap_or_else(|| state.get_account_by_id(account_id));
if pre.account != expected_pre {
return Err(NssaError::InvalidProgramBehavior);
}
ensure!(
pre.account == expected_pre,
NssaError::InvalidProgramBehavior
);
// Check that authorization flags are consistent with the provided ones or
// authorized by program through the PDA mechanism
let is_authorized = signer_account_ids.contains(&account_id)
|| authorized_pdas.contains(&account_id);
if pre.is_authorized != is_authorized {
return Err(NssaError::InvalidProgramBehavior);
}
ensure!(
pre.is_authorized == is_authorized,
NssaError::InvalidProgramBehavior
);
}
// Verify execution corresponds to a well-behaved program.
// See the # Programs section for the definition of the `validate_execution` method.
if !validate_execution(
&program_output.pre_states,
&program_output.post_states,
chained_call.program_id,
) {
return Err(NssaError::InvalidProgramBehavior);
}
ensure!(
validate_execution(
&program_output.pre_states,
&program_output.post_states,
chained_call.program_id,
),
NssaError::InvalidProgramBehavior
);
for post in program_output
.post_states
@ -221,9 +225,10 @@ impl PublicTransaction {
}
Some(post)
}) {
if post.program_owner == DEFAULT_PROGRAM_ID {
return Err(NssaError::InvalidProgramBehavior);
}
ensure!(
post.program_owner != DEFAULT_PROGRAM_ID,
NssaError::InvalidProgramBehavior
);
}
Ok(state_diff)