diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index 3ecee30..d36cf8f 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -25,7 +25,7 @@ pub struct ChainedCall { pub struct ProgramOutput { pub pre_states: Vec, pub post_states: Vec, - pub chained_call: Option, + pub chained_call: Vec, } pub fn read_nssa_inputs() -> ProgramInput { @@ -42,7 +42,7 @@ pub fn write_nssa_outputs(pre_states: Vec, post_states: Vec let output = ProgramOutput { pre_states, post_states, - chained_call: None, + chained_call: Vec::new(), }; env::commit(&output); } @@ -50,7 +50,7 @@ pub fn write_nssa_outputs(pre_states: Vec, post_states: Vec pub fn write_nssa_outputs_with_chained_call( pre_states: Vec, post_states: Vec, - chained_call: Option, + chained_call: Vec, ) { let output = ProgramOutput { pre_states, diff --git a/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs index d8ed15d..530c87e 100644 --- a/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -31,7 +31,7 @@ fn main() { } = program_output; // TODO: implement chained calls for privacy preserving transactions - if chained_call.is_some() { + if !chained_call.is_empty() { panic!("Privacy preserving transactions do not support yet chained calls.") } diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index d118d0c..cce4ffd 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -105,6 +105,7 @@ impl PublicTransaction { let mut program_id = message.program_id; let mut instruction_data = message.instruction_data.clone(); + let mut chained_calls = Vec::new(); for _i in 0..MAX_NUMBER_CHAINED_CALLS { // Check the `program_id` corresponds to a deployed program @@ -147,7 +148,9 @@ impl PublicTransaction { state_diff.insert(pre.account_id, post.clone()); } - if let Some(next_chained_call) = program_output.chained_call { + chained_calls.extend_from_slice(&program_output.chained_call); + + if let Some(next_chained_call) = chained_calls.pop() { program_id = next_chained_call.program_id; instruction_data = next_chained_call.instruction_data; diff --git a/nssa/src/state.rs b/nssa/src/state.rs index 4120824..b0f60eb 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -2096,7 +2096,7 @@ pub mod tests { let expected_to_post = Account { program_owner: Program::chain_caller().id(), - balance: amount, + balance: amount * 2, // The `chain_caller` chains the program twice ..Account::default() }; @@ -2114,7 +2114,8 @@ pub mod tests { let from_post = state.get_account_by_address(&from); let to_post = state.get_account_by_address(&to); - assert_eq!(from_post.balance, initial_balance - amount); + // The `chain_caller` program calls the program twice + assert_eq!(from_post.balance, initial_balance - 2 * amount); assert_eq!(to_post, expected_to_post); } } diff --git a/nssa/test_program_methods/guest/src/bin/chain_caller.rs b/nssa/test_program_methods/guest/src/bin/chain_caller.rs index dfd77b1..c4a548b 100644 --- a/nssa/test_program_methods/guest/src/bin/chain_caller.rs +++ b/nssa/test_program_methods/guest/src/bin/chain_caller.rs @@ -5,7 +5,7 @@ use risc0_zkvm::serde::to_vec; type Instruction = (u128, ProgramId); -/// A program that calls another program. +/// A program that calls another program twice. /// It permutes the order of the input accounts on the subsequent call fn main() { let ProgramInput { @@ -20,11 +20,18 @@ fn main() { let instruction_data = to_vec(&balance).unwrap(); - let chained_call = Some(ChainedCall { - program_id, - instruction_data, - account_indices: vec![1, 0], // <- Account order permutation here - }); + let chained_call = vec![ + ChainedCall { + program_id, + instruction_data: instruction_data.clone(), + account_indices: vec![0, 1], + }, + ChainedCall { + program_id, + instruction_data, + account_indices: vec![1, 0], // <- Account order permutation here + }, + ]; write_nssa_outputs_with_chained_call( vec![sender_pre.clone(), receiver_pre.clone()],