add tail-chain logic for public transactions

This commit is contained in:
Sergio Chouhy 2025-10-29 01:51:09 -03:00
parent 6d9d6b3d28
commit 3a27719392
4 changed files with 68 additions and 18 deletions

View File

@ -27,9 +27,14 @@ fn main() {
let ProgramOutput {
pre_states,
post_states,
chained_call: _,
chained_call,
} = program_output;
// TODO: implement tail calls for privacy preserving transactions
if chained_call.is_some() {
panic!("Privacy preserving transactions do not support yet tail calls.")
}
// Check that there are no repeated account ids
if !validate_uniqueness_of_account_ids(&pre_states) {
panic!("Repeated account ids found")

View File

@ -48,7 +48,7 @@ impl Program {
&self,
pre_states: &[AccountWithMetadata],
instruction_data: &InstructionData,
) -> Result<Vec<Account>, NssaError> {
) -> Result<ProgramOutput, NssaError> {
// Write inputs to the program
let mut env_builder = ExecutorEnv::builder();
env_builder.session_limit(Some(MAX_NUM_CYCLES_PUBLIC_EXECUTION));
@ -62,12 +62,12 @@ impl Program {
.map_err(|e| NssaError::ProgramExecutionFailed(e.to_string()))?;
// Get outputs
let ProgramOutput { post_states, .. } = session_info
let program_output = session_info
.journal
.decode()
.map_err(|e| NssaError::ProgramExecutionFailed(e.to_string()))?;
Ok(post_states)
Ok(program_output)
}
/// Writes inputs to `env_builder` in the order expected by the programs
@ -221,12 +221,12 @@ mod tests {
balance: balance_to_move,
..Account::default()
};
let [sender_post, recipient_post] = program
let program_output = program
.execute(&[sender, recipient], &instruction_data)
.unwrap()
.try_into()
.unwrap();
let [sender_post, recipient_post] = program_output.post_states.try_into().unwrap();
assert_eq!(sender_post, expected_sender_post);
assert_eq!(recipient_post, expected_recipient_post);
}

View File

@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet};
use nssa_core::{
account::{Account, AccountWithMetadata},
address::Address,
program::validate_execution,
program::{ChainedCall, validate_execution},
};
use sha2::{Digest, digest::FixedOutput};
@ -18,6 +18,7 @@ pub struct PublicTransaction {
message: Message,
witness_set: WitnessSet,
}
const MAX_NUMBER_CHAINED_CALLS: usize = 10;
impl PublicTransaction {
pub fn new(message: Message, witness_set: WitnessSet) -> Self {
@ -100,21 +101,65 @@ impl PublicTransaction {
})
.collect();
// Check the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&message.program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into()));
let mut state_diff: HashMap<Address, Account> = message
.addresses
.iter()
.cloned()
.zip(pre_states.iter().map(|pre| pre.account.clone()))
.collect();
let mut chained_call = ChainedCall {
program_id: message.program_id,
instruction_data: message.instruction_data.clone(),
account_indices: (0..pre_states.len()).collect(),
};
// // Execute program
let post_states = program.execute(&pre_states, &message.instruction_data)?;
for _ in 0..MAX_NUMBER_CHAINED_CALLS {
// Check the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&chained_call.program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into()));
};
// Verify execution corresponds to a well-behaved program.
// See the # Programs section for the definition of the `validate_execution` method.
if !validate_execution(&pre_states, &post_states, message.program_id) {
return Err(NssaError::InvalidProgramBehavior);
let pre_states_chained_call = chained_call
.account_indices
.iter()
.map(|&i| {
pre_states
.get(i)
.ok_or_else(|| NssaError::InvalidInput("Invalid account indices".into()))
.cloned()
})
.collect::<Result<Vec<_>, NssaError>>()?;
let program_output =
program.execute(&pre_states_chained_call, &chained_call.instruction_data)?;
// Verify execution corresponds to a well-behaved program.
// See the # Programs section for the definition of the `validate_execution` method.
if validate_execution(
&program_output.pre_states,
&program_output.post_states,
message.program_id,
) {
for (pre, post) in program_output
.pre_states
.iter()
.zip(program_output.post_states)
{
state_diff.insert(pre.account_id, post);
}
} else {
return Err(NssaError::InvalidProgramBehavior);
}
if let Some(next_chained_call) = program_output.chained_call {
chained_call = next_chained_call;
} else {
break;
};
}
Ok(message.addresses.iter().cloned().zip(post_states).collect())
Ok(state_diff)
}
}