add tail-chain logic for public transactions

This commit is contained in:
Sergio Chouhy 2025-10-29 01:51:09 -03:00
parent 6d9d6b3d28
commit 3a27719392
4 changed files with 68 additions and 18 deletions

View File

@ -27,9 +27,14 @@ fn main() {
let ProgramOutput { let ProgramOutput {
pre_states, pre_states,
post_states, post_states,
chained_call: _, chained_call,
} = program_output; } = program_output;
// TODO: implement tail calls for privacy preserving transactions
if chained_call.is_some() {
panic!("Privacy preserving transactions do not support yet tail calls.")
}
// Check that there are no repeated account ids // Check that there are no repeated account ids
if !validate_uniqueness_of_account_ids(&pre_states) { if !validate_uniqueness_of_account_ids(&pre_states) {
panic!("Repeated account ids found") panic!("Repeated account ids found")

View File

@ -48,7 +48,7 @@ impl Program {
&self, &self,
pre_states: &[AccountWithMetadata], pre_states: &[AccountWithMetadata],
instruction_data: &InstructionData, instruction_data: &InstructionData,
) -> Result<Vec<Account>, NssaError> { ) -> Result<ProgramOutput, NssaError> {
// Write inputs to the program // Write inputs to the program
let mut env_builder = ExecutorEnv::builder(); let mut env_builder = ExecutorEnv::builder();
env_builder.session_limit(Some(MAX_NUM_CYCLES_PUBLIC_EXECUTION)); env_builder.session_limit(Some(MAX_NUM_CYCLES_PUBLIC_EXECUTION));
@ -62,12 +62,12 @@ impl Program {
.map_err(|e| NssaError::ProgramExecutionFailed(e.to_string()))?; .map_err(|e| NssaError::ProgramExecutionFailed(e.to_string()))?;
// Get outputs // Get outputs
let ProgramOutput { post_states, .. } = session_info let program_output = session_info
.journal .journal
.decode() .decode()
.map_err(|e| NssaError::ProgramExecutionFailed(e.to_string()))?; .map_err(|e| NssaError::ProgramExecutionFailed(e.to_string()))?;
Ok(post_states) Ok(program_output)
} }
/// Writes inputs to `env_builder` in the order expected by the programs /// Writes inputs to `env_builder` in the order expected by the programs
@ -221,12 +221,12 @@ mod tests {
balance: balance_to_move, balance: balance_to_move,
..Account::default() ..Account::default()
}; };
let [sender_post, recipient_post] = program let program_output = program
.execute(&[sender, recipient], &instruction_data) .execute(&[sender, recipient], &instruction_data)
.unwrap()
.try_into()
.unwrap(); .unwrap();
let [sender_post, recipient_post] = program_output.post_states.try_into().unwrap();
assert_eq!(sender_post, expected_sender_post); assert_eq!(sender_post, expected_sender_post);
assert_eq!(recipient_post, expected_recipient_post); assert_eq!(recipient_post, expected_recipient_post);
} }

View File

@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet};
use nssa_core::{ use nssa_core::{
account::{Account, AccountWithMetadata}, account::{Account, AccountWithMetadata},
address::Address, address::Address,
program::validate_execution, program::{ChainedCall, validate_execution},
}; };
use sha2::{Digest, digest::FixedOutput}; use sha2::{Digest, digest::FixedOutput};
@ -18,6 +18,7 @@ pub struct PublicTransaction {
message: Message, message: Message,
witness_set: WitnessSet, witness_set: WitnessSet,
} }
const MAX_NUMBER_CHAINED_CALLS: usize = 10;
impl PublicTransaction { impl PublicTransaction {
pub fn new(message: Message, witness_set: WitnessSet) -> Self { pub fn new(message: Message, witness_set: WitnessSet) -> Self {
@ -100,21 +101,65 @@ impl PublicTransaction {
}) })
.collect(); .collect();
// Check the `program_id` corresponds to a deployed program let mut state_diff: HashMap<Address, Account> = message
let Some(program) = state.programs().get(&message.program_id) else { .addresses
return Err(NssaError::InvalidInput("Unknown program".into())); .iter()
.cloned()
.zip(pre_states.iter().map(|pre| pre.account.clone()))
.collect();
let mut chained_call = ChainedCall {
program_id: message.program_id,
instruction_data: message.instruction_data.clone(),
account_indices: (0..pre_states.len()).collect(),
}; };
// // Execute program for _ in 0..MAX_NUMBER_CHAINED_CALLS {
let post_states = program.execute(&pre_states, &message.instruction_data)?; // Check the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&chained_call.program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into()));
};
// Verify execution corresponds to a well-behaved program. let pre_states_chained_call = chained_call
// See the # Programs section for the definition of the `validate_execution` method. .account_indices
if !validate_execution(&pre_states, &post_states, message.program_id) { .iter()
return Err(NssaError::InvalidProgramBehavior); .map(|&i| {
pre_states
.get(i)
.ok_or_else(|| NssaError::InvalidInput("Invalid account indices".into()))
.cloned()
})
.collect::<Result<Vec<_>, NssaError>>()?;
let program_output =
program.execute(&pre_states_chained_call, &chained_call.instruction_data)?;
// Verify execution corresponds to a well-behaved program.
// See the # Programs section for the definition of the `validate_execution` method.
if validate_execution(
&program_output.pre_states,
&program_output.post_states,
message.program_id,
) {
for (pre, post) in program_output
.pre_states
.iter()
.zip(program_output.post_states)
{
state_diff.insert(pre.account_id, post);
}
} else {
return Err(NssaError::InvalidProgramBehavior);
}
if let Some(next_chained_call) = program_output.chained_call {
chained_call = next_chained_call;
} else {
break;
};
} }
Ok(message.addresses.iter().cloned().zip(post_states).collect()) Ok(state_diff)
} }
} }