This commit is contained in:
Sergio Chouhy 2025-10-30 13:47:52 -03:00
parent 0fb72e452f
commit 12974f6f6b
2 changed files with 48 additions and 36 deletions

View File

@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet};
use nssa_core::{
account::{Account, AccountWithMetadata},
address::Address,
program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution},
program::{DEFAULT_PROGRAM_ID, validate_execution},
};
use sha2::{Digest, digest::FixedOutput};
@ -89,7 +89,7 @@ impl PublicTransaction {
}
// Build pre_states for execution
let pre_states: Vec<_> = message
let mut input_pre_states: Vec<_> = message
.addresses
.iter()
.map(|address| {
@ -101,66 +101,80 @@ impl PublicTransaction {
})
.collect();
let mut state_diff: HashMap<Address, Account> = message
.addresses
.iter()
.cloned()
.zip(pre_states.iter().map(|pre| pre.account.clone()))
.collect();
let mut state_diff: HashMap<Address, Account> = HashMap::new();
let mut chained_call = ChainedCall {
program_id: message.program_id,
instruction_data: message.instruction_data.clone(),
account_indices: (0..pre_states.len()).collect(),
};
let mut program_id = message.program_id;
let mut instruction_data = message.instruction_data.clone();
for _i in 0..MAX_NUMBER_CHAINED_CALLS {
// Check the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&chained_call.program_id) else {
let Some(program) = state.programs().get(&program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into()));
};
let pre_states_chained_call = chained_call
.account_indices
.iter()
.map(|&i| {
pre_states
.get(i)
.ok_or_else(|| NssaError::InvalidInput("Invalid account indices".into()))
.cloned()
})
.collect::<Result<Vec<_>, NssaError>>()?;
let mut program_output = program.execute(&input_pre_states, &instruction_data)?;
let mut program_output =
program.execute(&pre_states_chained_call, &chained_call.instruction_data)?;
// This check is equivalent to checking that the program output pre_states coinicide
// with the values in the public state or with any modifications to those values
// during the chain of calls.
if input_pre_states != program_output.pre_states {
return Err(NssaError::InvalidProgramBehavior);
}
// Verify execution corresponds to a well-behaved program.
// See the # Programs section for the definition of the `validate_execution` method.
if !validate_execution(
&program_output.pre_states,
&program_output.post_states,
chained_call.program_id,
program_id,
) {
return Err(NssaError::InvalidProgramBehavior);
}
// The invoked program claims the accounts with default program id.
for post in program_output.post_states.iter_mut() {
// The invoked program claims the accounts with default program id.
if post.program_owner == DEFAULT_PROGRAM_ID {
post.program_owner = chained_call.program_id;
post.program_owner = program_id;
}
}
// Update the state diff
for (pre, post) in program_output
.pre_states
.iter()
.zip(program_output.post_states)
.zip(program_output.post_states.iter())
{
state_diff.insert(pre.account_id, post);
state_diff.insert(pre.account_id, post.clone());
}
if let Some(next_chained_call) = program_output.chained_call {
chained_call = next_chained_call;
program_id = next_chained_call.program_id;
instruction_data = next_chained_call.instruction_data;
// Build post states with metadata for next call
let mut post_states_with_metadata = Vec::new();
for (pre, post) in program_output
.pre_states
.iter()
.zip(program_output.post_states)
{
let mut post_with_metadata = pre.clone();
post_with_metadata.account = post.clone();
post_states_with_metadata.push(post_with_metadata);
}
input_pre_states = next_chained_call
.account_indices
.iter()
.map(|&i| {
post_states_with_metadata
.get(i)
.ok_or_else(|| {
NssaError::InvalidInput("Invalid account indices".into())
})
.cloned()
})
.collect::<Result<Vec<_>, NssaError>>()?;
} else {
break;
};

View File

@ -6,9 +6,7 @@ use crate::{
};
use nssa_core::{
Commitment, CommitmentSetDigest, DUMMY_COMMITMENT, MembershipProof, Nullifier,
account::Account,
address::Address,
program::ProgramId,
account::Account, address::Address, program::ProgramId,
};
use std::collections::{HashMap, HashSet};
@ -2097,7 +2095,7 @@ pub mod tests {
(amount, Program::authenticated_transfer_program().id());
let expected_to_post = Account {
program_owner: Program::authenticated_transfer_program().id(),
program_owner: Program::chain_caller().id(),
balance: amount,
..Account::default()
};