diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index 900a476..d118d0c 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet}; use nssa_core::{ account::{Account, AccountWithMetadata}, address::Address, - program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution}, + program::{DEFAULT_PROGRAM_ID, validate_execution}, }; use sha2::{Digest, digest::FixedOutput}; @@ -89,7 +89,7 @@ impl PublicTransaction { } // Build pre_states for execution - let pre_states: Vec<_> = message + let mut input_pre_states: Vec<_> = message .addresses .iter() .map(|address| { @@ -101,66 +101,80 @@ impl PublicTransaction { }) .collect(); - let mut state_diff: HashMap = message - .addresses - .iter() - .cloned() - .zip(pre_states.iter().map(|pre| pre.account.clone())) - .collect(); + let mut state_diff: HashMap = HashMap::new(); - let mut chained_call = ChainedCall { - program_id: message.program_id, - instruction_data: message.instruction_data.clone(), - account_indices: (0..pre_states.len()).collect(), - }; + let mut program_id = message.program_id; + let mut instruction_data = message.instruction_data.clone(); for _i in 0..MAX_NUMBER_CHAINED_CALLS { // Check the `program_id` corresponds to a deployed program - let Some(program) = state.programs().get(&chained_call.program_id) else { + let Some(program) = state.programs().get(&program_id) else { return Err(NssaError::InvalidInput("Unknown program".into())); }; - let pre_states_chained_call = chained_call - .account_indices - .iter() - .map(|&i| { - pre_states - .get(i) - .ok_or_else(|| NssaError::InvalidInput("Invalid account indices".into())) - .cloned() - }) - .collect::, NssaError>>()?; + let mut program_output = program.execute(&input_pre_states, &instruction_data)?; - let mut program_output = - program.execute(&pre_states_chained_call, &chained_call.instruction_data)?; + // This check is equivalent to checking that the program output pre_states coinicide + // with the values in the public state or with any modifications to those values + // during the chain of calls. + if input_pre_states != program_output.pre_states { + return Err(NssaError::InvalidProgramBehavior); + } // Verify execution corresponds to a well-behaved program. // See the # Programs section for the definition of the `validate_execution` method. if !validate_execution( &program_output.pre_states, &program_output.post_states, - chained_call.program_id, + program_id, ) { return Err(NssaError::InvalidProgramBehavior); } + // The invoked program claims the accounts with default program id. for post in program_output.post_states.iter_mut() { - // The invoked program claims the accounts with default program id. if post.program_owner == DEFAULT_PROGRAM_ID { - post.program_owner = chained_call.program_id; + post.program_owner = program_id; } } + // Update the state diff for (pre, post) in program_output .pre_states .iter() - .zip(program_output.post_states) + .zip(program_output.post_states.iter()) { - state_diff.insert(pre.account_id, post); + state_diff.insert(pre.account_id, post.clone()); } if let Some(next_chained_call) = program_output.chained_call { - chained_call = next_chained_call; + program_id = next_chained_call.program_id; + instruction_data = next_chained_call.instruction_data; + + // Build post states with metadata for next call + let mut post_states_with_metadata = Vec::new(); + for (pre, post) in program_output + .pre_states + .iter() + .zip(program_output.post_states) + { + let mut post_with_metadata = pre.clone(); + post_with_metadata.account = post.clone(); + post_states_with_metadata.push(post_with_metadata); + } + + input_pre_states = next_chained_call + .account_indices + .iter() + .map(|&i| { + post_states_with_metadata + .get(i) + .ok_or_else(|| { + NssaError::InvalidInput("Invalid account indices".into()) + }) + .cloned() + }) + .collect::, NssaError>>()?; } else { break; }; diff --git a/nssa/src/state.rs b/nssa/src/state.rs index 2119929..4120824 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -6,9 +6,7 @@ use crate::{ }; use nssa_core::{ Commitment, CommitmentSetDigest, DUMMY_COMMITMENT, MembershipProof, Nullifier, - account::Account, - address::Address, - program::ProgramId, + account::Account, address::Address, program::ProgramId, }; use std::collections::{HashMap, HashSet}; @@ -2097,7 +2095,7 @@ pub mod tests { (amount, Program::authenticated_transfer_program().id()); let expected_to_post = Account { - program_owner: Program::authenticated_transfer_program().id(), + program_owner: Program::chain_caller().id(), balance: amount, ..Account::default() };