use std::collections::{HashMap, HashSet, VecDeque}; use borsh::{BorshDeserialize, BorshSerialize}; use log::debug; use nssa_core::{ BlockId, Timestamp, account::{Account, AccountId, AccountWithMetadata}, program::{ChainedCall, Claim, DEFAULT_PROGRAM_ID, validate_execution}, }; use sha2::{Digest as _, digest::FixedOutput as _}; use crate::{ V03State, ensure, error::NssaError, public_transaction::{Message, WitnessSet}, state::MAX_NUMBER_CHAINED_CALLS, }; #[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)] pub struct PublicTransaction { pub message: Message, pub witness_set: WitnessSet, } impl PublicTransaction { #[must_use] pub const fn new(message: Message, witness_set: WitnessSet) -> Self { Self { message, witness_set, } } #[must_use] pub const fn message(&self) -> &Message { &self.message } #[must_use] pub const fn witness_set(&self) -> &WitnessSet { &self.witness_set } pub(crate) fn signer_account_ids(&self) -> Vec { self.witness_set .signatures_and_public_keys() .iter() .map(|(_, public_key)| AccountId::from(public_key)) .collect() } #[must_use] pub fn affected_public_account_ids(&self) -> Vec { let mut acc_set = self .signer_account_ids() .into_iter() .collect::>(); acc_set.extend(&self.message.account_ids); acc_set.into_iter().collect() } #[must_use] pub fn hash(&self) -> [u8; 32] { let bytes = self.to_bytes(); let mut hasher = sha2::Sha256::new(); hasher.update(&bytes); hasher.finalize_fixed().into() } pub(crate) fn validate_and_produce_public_state_diff( &self, state: &V03State, block_id: BlockId, timestamp: Timestamp, ) -> Result, NssaError> { let message = self.message(); let witness_set = self.witness_set(); // All account_ids must be different ensure!( message.account_ids.iter().collect::>().len() == message.account_ids.len(), NssaError::InvalidInput("Duplicate account_ids found in message".into(),) ); // Check exactly one nonce is provided for each signature ensure!( message.nonces.len() == witness_set.signatures_and_public_keys.len(), NssaError::InvalidInput( "Mismatch between number of nonces and signatures/public keys".into(), ) ); // Check the signatures are valid ensure!( witness_set.is_valid_for(message), NssaError::InvalidInput("Invalid signature for given message and public key".into()) ); let signer_account_ids = self.signer_account_ids(); // Check nonces corresponds to the current nonces on the public state. for (account_id, nonce) in signer_account_ids.iter().zip(&message.nonces) { let current_nonce = state.get_account_by_id(*account_id).nonce; ensure!( current_nonce == *nonce, NssaError::InvalidInput("Nonce mismatch".into()) ); } // Build pre_states for execution let input_pre_states: Vec<_> = message .account_ids .iter() .map(|account_id| { AccountWithMetadata::new( state.get_account_by_id(*account_id), signer_account_ids.contains(account_id), *account_id, ) }) .collect(); let mut state_diff: HashMap = HashMap::new(); let initial_call = ChainedCall { program_id: message.program_id, instruction_data: message.instruction_data.clone(), pre_states: input_pre_states, pda_seeds: vec![], }; let mut chained_calls = VecDeque::from_iter([(initial_call, None)]); let mut chain_calls_counter = 0; while let Some((chained_call, caller_program_id)) = chained_calls.pop_front() { ensure!( chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS, NssaError::MaxChainedCallsDepthExceeded ); // Check that the `program_id` corresponds to a deployed program let Some(program) = state.programs().get(&chained_call.program_id) else { return Err(NssaError::InvalidInput("Unknown program".into())); }; debug!( "Program {:?} pre_states: {:?}, instruction_data: {:?}", chained_call.program_id, chained_call.pre_states, chained_call.instruction_data ); let mut program_output = program.execute(&chained_call.pre_states, &chained_call.instruction_data)?; debug!( "Program {:?} output: {:?}", chained_call.program_id, program_output ); let authorized_pdas = nssa_core::program::compute_authorized_pdas( caller_program_id, &chained_call.pda_seeds, ); let is_authorized = |account_id: &AccountId| { signer_account_ids.contains(account_id) || authorized_pdas.contains(account_id) }; for pre in &program_output.pre_states { let account_id = pre.account_id; // Check that the program output pre_states coincide with the values in the public // state or with any modifications to those values during the chain of calls. let expected_pre = state_diff .get(&account_id) .cloned() .unwrap_or_else(|| state.get_account_by_id(account_id)); ensure!( pre.account == expected_pre, NssaError::InvalidProgramBehavior ); // Check that authorization flags are consistent with the provided ones or // authorized by program through the PDA mechanism ensure!( pre.is_authorized == is_authorized(&account_id), NssaError::InvalidProgramBehavior ); } // Verify that the program output's self_program_id matches the expected program ID. ensure!( program_output.self_program_id == chained_call.program_id, NssaError::InvalidProgramBehavior ); // Verify execution corresponds to a well-behaved program. // See the # Programs section for the definition of the `validate_execution` method. ensure!( validate_execution( &program_output.pre_states, &program_output.post_states, chained_call.program_id, ), NssaError::InvalidProgramBehavior ); // Verify validity window ensure!( program_output.block_validity_window.is_valid_for(block_id) && program_output .timestamp_validity_window .is_valid_for(timestamp), NssaError::OutOfValidityWindow ); for (i, post) in program_output.post_states.iter_mut().enumerate() { let Some(claim) = post.required_claim() else { continue; }; // The invoked program can only claim accounts with default program id. ensure!( post.account().program_owner == DEFAULT_PROGRAM_ID, NssaError::InvalidProgramBehavior ); let account_id = program_output.pre_states[i].account_id; match claim { Claim::Authorized => { // The program can only claim accounts that were authorized by the signer. ensure!( is_authorized(&account_id), NssaError::InvalidProgramBehavior ); } Claim::Pda(seed) => { // The program can only claim accounts that correspond to the PDAs it is // authorized to claim. let pda = AccountId::from((&chained_call.program_id, &seed)); ensure!(account_id == pda, NssaError::InvalidProgramBehavior); } } post.account_mut().program_owner = chained_call.program_id; } // Update the state diff for (pre, post) in program_output .pre_states .iter() .zip(program_output.post_states.iter()) { state_diff.insert(pre.account_id, post.account().clone()); } for new_call in program_output.chained_calls.into_iter().rev() { chained_calls.push_front((new_call, Some(chained_call.program_id))); } chain_calls_counter = chain_calls_counter .checked_add(1) .expect("we check the max depth at the beginning of the loop"); } // Check that all modified uninitialized accounts where claimed for post in state_diff.iter().filter_map(|(account_id, post)| { let pre = state.get_account_by_id(*account_id); if pre.program_owner != DEFAULT_PROGRAM_ID { return None; } if pre == *post { return None; } Some(post) }) { ensure!( post.program_owner != DEFAULT_PROGRAM_ID, NssaError::InvalidProgramBehavior ); } Ok(state_diff) } } #[cfg(test)] pub mod tests { use sha2::{Digest as _, digest::FixedOutput as _}; use crate::{ AccountId, PrivateKey, PublicKey, PublicTransaction, Signature, V03State, error::NssaError, program::Program, public_transaction::{Message, WitnessSet}, }; fn keys_for_tests() -> (PrivateKey, PrivateKey, AccountId, AccountId) { let key1 = PrivateKey::try_new([1; 32]).unwrap(); let key2 = PrivateKey::try_new([2; 32]).unwrap(); let addr1 = AccountId::from(&PublicKey::new_from_private_key(&key1)); let addr2 = AccountId::from(&PublicKey::new_from_private_key(&key2)); (key1, key2, addr1, addr2) } fn state_for_tests() -> V03State { let (_, _, addr1, addr2) = keys_for_tests(); let initial_data = [(addr1, 10000), (addr2, 20000)]; V03State::new_with_genesis_accounts(&initial_data, &[]) } fn transaction_for_tests() -> PublicTransaction { let (key1, key2, addr1, addr2) = keys_for_tests(); let nonces = vec![0_u128.into(), 0_u128.into()]; let instruction = 1337; let message = Message::try_new( Program::authenticated_transfer_program().id(), vec![addr1, addr2], nonces, instruction, ) .unwrap(); let witness_set = WitnessSet::for_message(&message, &[&key1, &key2]); PublicTransaction::new(message, witness_set) } #[test] fn new_constructor() { let tx = transaction_for_tests(); let message = tx.message().clone(); let witness_set = tx.witness_set().clone(); let tx_from_constructor = PublicTransaction::new(message.clone(), witness_set.clone()); assert_eq!(tx_from_constructor.message, message); assert_eq!(tx_from_constructor.witness_set, witness_set); } #[test] fn message_getter() { let tx = transaction_for_tests(); assert_eq!(&tx.message, tx.message()); } #[test] fn witness_set_getter() { let tx = transaction_for_tests(); assert_eq!(&tx.witness_set, tx.witness_set()); } #[test] fn signer_account_ids() { let tx = transaction_for_tests(); let expected_signer_account_ids = vec![ AccountId::new([ 148, 179, 206, 253, 199, 51, 82, 86, 232, 2, 152, 122, 80, 243, 54, 207, 237, 112, 83, 153, 44, 59, 204, 49, 128, 84, 160, 227, 216, 149, 97, 102, ]), AccountId::new([ 30, 145, 107, 3, 207, 73, 192, 230, 160, 63, 238, 207, 18, 69, 54, 216, 103, 244, 92, 94, 124, 248, 42, 16, 141, 19, 119, 18, 14, 226, 140, 204, ]), ]; let signer_account_ids = tx.signer_account_ids(); assert_eq!(signer_account_ids, expected_signer_account_ids); } #[test] fn public_transaction_encoding_bytes_roundtrip() { let tx = transaction_for_tests(); let bytes = tx.to_bytes(); let tx_from_bytes = PublicTransaction::from_bytes(&bytes).unwrap(); assert_eq!(tx, tx_from_bytes); } #[test] fn hash_is_sha256_of_transaction_bytes() { let tx = transaction_for_tests(); let hash = tx.hash(); let expected_hash: [u8; 32] = { let bytes = tx.to_bytes(); let mut hasher = sha2::Sha256::new(); hasher.update(&bytes); hasher.finalize_fixed().into() }; assert_eq!(hash, expected_hash); } #[test] fn account_id_list_cant_have_duplicates() { let (key1, _, addr1, _) = keys_for_tests(); let state = state_for_tests(); let nonces = vec![0_u128.into(), 0_u128.into()]; let instruction = 1337; let message = Message::try_new( Program::authenticated_transfer_program().id(), vec![addr1, addr1], nonces, instruction, ) .unwrap(); let witness_set = WitnessSet::for_message(&message, &[&key1, &key1]); let tx = PublicTransaction::new(message, witness_set); let result = tx.validate_and_produce_public_state_diff(&state, 1, 0); assert!(matches!(result, Err(NssaError::InvalidInput(_)))); } #[test] fn number_of_nonces_must_match_number_of_signatures() { let (key1, key2, addr1, addr2) = keys_for_tests(); let state = state_for_tests(); let nonces = vec![0_u128.into()]; let instruction = 1337; let message = Message::try_new( Program::authenticated_transfer_program().id(), vec![addr1, addr2], nonces, instruction, ) .unwrap(); let witness_set = WitnessSet::for_message(&message, &[&key1, &key2]); let tx = PublicTransaction::new(message, witness_set); let result = tx.validate_and_produce_public_state_diff(&state, 1, 0); assert!(matches!(result, Err(NssaError::InvalidInput(_)))); } #[test] fn all_signatures_must_be_valid() { let (key1, key2, addr1, addr2) = keys_for_tests(); let state = state_for_tests(); let nonces = vec![0_u128.into(), 0_u128.into()]; let instruction = 1337; let message = Message::try_new( Program::authenticated_transfer_program().id(), vec![addr1, addr2], nonces, instruction, ) .unwrap(); let mut witness_set = WitnessSet::for_message(&message, &[&key1, &key2]); witness_set.signatures_and_public_keys[0].0 = Signature::new_for_tests([1; 64]); let tx = PublicTransaction::new(message, witness_set); let result = tx.validate_and_produce_public_state_diff(&state, 1, 0); assert!(matches!(result, Err(NssaError::InvalidInput(_)))); } #[test] fn nonces_must_match_the_state_current_nonces() { let (key1, key2, addr1, addr2) = keys_for_tests(); let state = state_for_tests(); let nonces = vec![0_u128.into(), 1_u128.into()]; let instruction = 1337; let message = Message::try_new( Program::authenticated_transfer_program().id(), vec![addr1, addr2], nonces, instruction, ) .unwrap(); let witness_set = WitnessSet::for_message(&message, &[&key1, &key2]); let tx = PublicTransaction::new(message, witness_set); let result = tx.validate_and_produce_public_state_diff(&state, 1, 0); assert!(matches!(result, Err(NssaError::InvalidInput(_)))); } #[test] fn program_id_must_belong_to_bulitin_program_ids() { let (key1, key2, addr1, addr2) = keys_for_tests(); let state = state_for_tests(); let nonces = vec![0_u128.into(), 0_u128.into()]; let instruction = 1337; let unknown_program_id = [0xdead_beef; 8]; let message = Message::try_new(unknown_program_id, vec![addr1, addr2], nonces, instruction).unwrap(); let witness_set = WitnessSet::for_message(&message, &[&key1, &key2]); let tx = PublicTransaction::new(message, witness_set); let result = tx.validate_and_produce_public_state_diff(&state, 1, 0); assert!(matches!(result, Err(NssaError::InvalidInput(_)))); } }