diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs
index 900a476..d118d0c 100644
--- a/nssa/src/public_transaction/transaction.rs
+++ b/nssa/src/public_transaction/transaction.rs
@@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet};
use nssa_core::{
account::{Account, AccountWithMetadata},
address::Address,
- program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution},
+ program::{DEFAULT_PROGRAM_ID, validate_execution},
};
use sha2::{Digest, digest::FixedOutput};
@@ -89,7 +89,7 @@ impl PublicTransaction {
}
// Build pre_states for execution
- let pre_states: Vec<_> = message
+ let mut input_pre_states: Vec<_> = message
.addresses
.iter()
.map(|address| {
@@ -101,66 +101,80 @@ impl PublicTransaction {
})
.collect();
- let mut state_diff: HashMap
= message
- .addresses
- .iter()
- .cloned()
- .zip(pre_states.iter().map(|pre| pre.account.clone()))
- .collect();
+ let mut state_diff: HashMap = HashMap::new();
- let mut chained_call = ChainedCall {
- program_id: message.program_id,
- instruction_data: message.instruction_data.clone(),
- account_indices: (0..pre_states.len()).collect(),
- };
+ let mut program_id = message.program_id;
+ let mut instruction_data = message.instruction_data.clone();
for _i in 0..MAX_NUMBER_CHAINED_CALLS {
// Check the `program_id` corresponds to a deployed program
- let Some(program) = state.programs().get(&chained_call.program_id) else {
+ let Some(program) = state.programs().get(&program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into()));
};
- let pre_states_chained_call = chained_call
- .account_indices
- .iter()
- .map(|&i| {
- pre_states
- .get(i)
- .ok_or_else(|| NssaError::InvalidInput("Invalid account indices".into()))
- .cloned()
- })
- .collect::, NssaError>>()?;
+ let mut program_output = program.execute(&input_pre_states, &instruction_data)?;
- let mut program_output =
- program.execute(&pre_states_chained_call, &chained_call.instruction_data)?;
+ // This check is equivalent to checking that the program output pre_states coinicide
+ // with the values in the public state or with any modifications to those values
+ // during the chain of calls.
+ if input_pre_states != program_output.pre_states {
+ return Err(NssaError::InvalidProgramBehavior);
+ }
// Verify execution corresponds to a well-behaved program.
// See the # Programs section for the definition of the `validate_execution` method.
if !validate_execution(
&program_output.pre_states,
&program_output.post_states,
- chained_call.program_id,
+ program_id,
) {
return Err(NssaError::InvalidProgramBehavior);
}
+ // The invoked program claims the accounts with default program id.
for post in program_output.post_states.iter_mut() {
- // The invoked program claims the accounts with default program id.
if post.program_owner == DEFAULT_PROGRAM_ID {
- post.program_owner = chained_call.program_id;
+ post.program_owner = program_id;
}
}
+ // Update the state diff
for (pre, post) in program_output
.pre_states
.iter()
- .zip(program_output.post_states)
+ .zip(program_output.post_states.iter())
{
- state_diff.insert(pre.account_id, post);
+ state_diff.insert(pre.account_id, post.clone());
}
if let Some(next_chained_call) = program_output.chained_call {
- chained_call = next_chained_call;
+ program_id = next_chained_call.program_id;
+ instruction_data = next_chained_call.instruction_data;
+
+ // Build post states with metadata for next call
+ let mut post_states_with_metadata = Vec::new();
+ for (pre, post) in program_output
+ .pre_states
+ .iter()
+ .zip(program_output.post_states)
+ {
+ let mut post_with_metadata = pre.clone();
+ post_with_metadata.account = post.clone();
+ post_states_with_metadata.push(post_with_metadata);
+ }
+
+ input_pre_states = next_chained_call
+ .account_indices
+ .iter()
+ .map(|&i| {
+ post_states_with_metadata
+ .get(i)
+ .ok_or_else(|| {
+ NssaError::InvalidInput("Invalid account indices".into())
+ })
+ .cloned()
+ })
+ .collect::, NssaError>>()?;
} else {
break;
};
diff --git a/nssa/src/state.rs b/nssa/src/state.rs
index 2119929..4120824 100644
--- a/nssa/src/state.rs
+++ b/nssa/src/state.rs
@@ -6,9 +6,7 @@ use crate::{
};
use nssa_core::{
Commitment, CommitmentSetDigest, DUMMY_COMMITMENT, MembershipProof, Nullifier,
- account::Account,
- address::Address,
- program::ProgramId,
+ account::Account, address::Address, program::ProgramId,
};
use std::collections::{HashMap, HashSet};
@@ -2097,7 +2095,7 @@ pub mod tests {
(amount, Program::authenticated_transfer_program().id());
let expected_to_post = Account {
- program_owner: Program::authenticated_transfer_program().id(),
+ program_owner: Program::chain_caller().id(),
balance: amount,
..Account::default()
};