refactor call stack execution loop

This commit is contained in:
Sergio Chouhy 2025-11-27 09:54:14 -03:00
parent d1d2292028
commit 577fad6d5f

View File

@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
use nssa_core::{ use nssa_core::{
account::{Account, AccountId, AccountWithMetadata}, account::{Account, AccountId, AccountWithMetadata},
program::{DEFAULT_PROGRAM_ID, validate_execution}, program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution},
}; };
use sha2::{Digest, digest::FixedOutput}; use sha2::{Digest, digest::FixedOutput};
@ -88,7 +88,7 @@ impl PublicTransaction {
} }
// Build pre_states for execution // Build pre_states for execution
let mut input_pre_states: Vec<_> = message let input_pre_states: Vec<_> = message
.account_ids .account_ids
.iter() .iter()
.map(|account_id| { .map(|account_id| {
@ -102,17 +102,27 @@ impl PublicTransaction {
let mut state_diff: HashMap<AccountId, Account> = HashMap::new(); let mut state_diff: HashMap<AccountId, Account> = HashMap::new();
let mut program_id = message.program_id; let initial_call = ChainedCall {
let mut instruction_data = message.instruction_data.clone(); program_id: message.program_id,
let mut chained_calls = Vec::new(); instruction_data: message.instruction_data.clone(),
pre_states: input_pre_states,
};
let mut chained_calls = Vec::from_iter([initial_call]);
let mut chain_calls_counter = 0;
while let Some(chained_call) = chained_calls.pop() {
if chain_calls_counter > MAX_NUMBER_CHAINED_CALLS {
return Err(NssaError::MaxChainedCallsDepthExceeded);
}
for _i in 0..MAX_NUMBER_CHAINED_CALLS {
// Check the `program_id` corresponds to a deployed program // Check the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&program_id) else { let Some(program) = state.programs().get(&chained_call.program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into())); return Err(NssaError::InvalidInput("Unknown program".into()));
}; };
let mut program_output = program.execute(&input_pre_states, &instruction_data)?; let mut program_output =
program.execute(&chained_call.pre_states, &chained_call.instruction_data)?;
for pre in program_output.pre_states.iter() { for pre in program_output.pre_states.iter() {
let account_id = pre.account_id; let account_id = pre.account_id;
@ -137,7 +147,7 @@ impl PublicTransaction {
if !validate_execution( if !validate_execution(
&program_output.pre_states, &program_output.pre_states,
&program_output.post_states, &program_output.post_states,
program_id, chained_call.program_id,
) { ) {
return Err(NssaError::InvalidProgramBehavior); return Err(NssaError::InvalidProgramBehavior);
} }
@ -145,7 +155,7 @@ impl PublicTransaction {
// The invoked program claims the accounts with default program id. // The invoked program claims the accounts with default program id.
for post in program_output.post_states.iter_mut() { for post in program_output.post_states.iter_mut() {
if post.program_owner == DEFAULT_PROGRAM_ID { if post.program_owner == DEFAULT_PROGRAM_ID {
post.program_owner = program_id; post.program_owner = chained_call.program_id;
} }
} }
@ -159,21 +169,10 @@ impl PublicTransaction {
} }
chained_calls.extend_from_slice(&program_output.chained_calls); chained_calls.extend_from_slice(&program_output.chained_calls);
chain_calls_counter += 1;
if let Some(next_chained_call) = chained_calls.pop() {
program_id = next_chained_call.program_id;
instruction_data = next_chained_call.instruction_data;
input_pre_states = next_chained_call.pre_states;
} else {
break;
};
} }
if chained_calls.is_empty() { Ok(state_diff)
Ok(state_diff)
} else {
Err(NssaError::MaxChainedCallsDepthExceeded)
}
} }
} }