add multi chain calls

This commit is contained in:
Sergio Chouhy 2025-11-12 19:08:46 -03:00
parent d69e8a292e
commit 2e582e7874
5 changed files with 24 additions and 13 deletions

View File

@ -25,7 +25,7 @@ pub struct ChainedCall {
pub struct ProgramOutput { pub struct ProgramOutput {
pub pre_states: Vec<AccountWithMetadata>, pub pre_states: Vec<AccountWithMetadata>,
pub post_states: Vec<Account>, pub post_states: Vec<Account>,
pub chained_call: Option<ChainedCall>, pub chained_call: Vec<ChainedCall>,
} }
pub fn read_nssa_inputs<T: DeserializeOwned>() -> ProgramInput<T> { pub fn read_nssa_inputs<T: DeserializeOwned>() -> ProgramInput<T> {
@ -42,7 +42,7 @@ pub fn write_nssa_outputs(pre_states: Vec<AccountWithMetadata>, post_states: Vec
let output = ProgramOutput { let output = ProgramOutput {
pre_states, pre_states,
post_states, post_states,
chained_call: None, chained_call: Vec::new(),
}; };
env::commit(&output); env::commit(&output);
} }
@ -50,7 +50,7 @@ pub fn write_nssa_outputs(pre_states: Vec<AccountWithMetadata>, post_states: Vec
pub fn write_nssa_outputs_with_chained_call( pub fn write_nssa_outputs_with_chained_call(
pre_states: Vec<AccountWithMetadata>, pre_states: Vec<AccountWithMetadata>,
post_states: Vec<Account>, post_states: Vec<Account>,
chained_call: Option<ChainedCall>, chained_call: Vec<ChainedCall>,
) { ) {
let output = ProgramOutput { let output = ProgramOutput {
pre_states, pre_states,

View File

@ -31,7 +31,7 @@ fn main() {
} = program_output; } = program_output;
// TODO: implement chained calls for privacy preserving transactions // TODO: implement chained calls for privacy preserving transactions
if chained_call.is_some() { if !chained_call.is_empty() {
panic!("Privacy preserving transactions do not support yet chained calls.") panic!("Privacy preserving transactions do not support yet chained calls.")
} }

View File

@ -105,6 +105,7 @@ impl PublicTransaction {
let mut program_id = message.program_id; let mut program_id = message.program_id;
let mut instruction_data = message.instruction_data.clone(); let mut instruction_data = message.instruction_data.clone();
let mut chained_calls = Vec::new();
for _i in 0..MAX_NUMBER_CHAINED_CALLS { for _i in 0..MAX_NUMBER_CHAINED_CALLS {
// Check the `program_id` corresponds to a deployed program // Check the `program_id` corresponds to a deployed program
@ -147,7 +148,9 @@ impl PublicTransaction {
state_diff.insert(pre.account_id, post.clone()); state_diff.insert(pre.account_id, post.clone());
} }
if let Some(next_chained_call) = program_output.chained_call { chained_calls.extend_from_slice(&program_output.chained_call);
if let Some(next_chained_call) = chained_calls.pop() {
program_id = next_chained_call.program_id; program_id = next_chained_call.program_id;
instruction_data = next_chained_call.instruction_data; instruction_data = next_chained_call.instruction_data;

View File

@ -2096,7 +2096,7 @@ pub mod tests {
let expected_to_post = Account { let expected_to_post = Account {
program_owner: Program::chain_caller().id(), program_owner: Program::chain_caller().id(),
balance: amount, balance: amount * 2, // The `chain_caller` chains the program twice
..Account::default() ..Account::default()
}; };
@ -2114,7 +2114,8 @@ pub mod tests {
let from_post = state.get_account_by_address(&from); let from_post = state.get_account_by_address(&from);
let to_post = state.get_account_by_address(&to); let to_post = state.get_account_by_address(&to);
assert_eq!(from_post.balance, initial_balance - amount); // The `chain_caller` program calls the program twice
assert_eq!(from_post.balance, initial_balance - 2 * amount);
assert_eq!(to_post, expected_to_post); assert_eq!(to_post, expected_to_post);
} }
} }

View File

@ -5,7 +5,7 @@ use risc0_zkvm::serde::to_vec;
type Instruction = (u128, ProgramId); type Instruction = (u128, ProgramId);
/// A program that calls another program. /// A program that calls another program twice.
/// It permutes the order of the input accounts on the subsequent call /// It permutes the order of the input accounts on the subsequent call
fn main() { fn main() {
let ProgramInput { let ProgramInput {
@ -20,11 +20,18 @@ fn main() {
let instruction_data = to_vec(&balance).unwrap(); let instruction_data = to_vec(&balance).unwrap();
let chained_call = Some(ChainedCall { let chained_call = vec![
program_id, ChainedCall {
instruction_data, program_id,
account_indices: vec![1, 0], // <- Account order permutation here instruction_data: instruction_data.clone(),
}); account_indices: vec![0, 1],
},
ChainedCall {
program_id,
instruction_data,
account_indices: vec![1, 0], // <- Account order permutation here
},
];
write_nssa_outputs_with_chained_call( write_nssa_outputs_with_chained_call(
vec![sender_pre.clone(), receiver_pre.clone()], vec![sender_pre.clone(), receiver_pre.clone()],