diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index a96d3cf..6582984 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -25,7 +25,7 @@ pub struct ChainedCall { pub struct ProgramOutput { pub pre_states: Vec, pub post_states: Vec, - pub chained_call: Option + pub chained_call: Option, } pub fn read_nssa_inputs() -> ProgramInput { @@ -42,7 +42,20 @@ pub fn write_nssa_outputs(pre_states: Vec, post_states: Vec let output = ProgramOutput { pre_states, post_states, - chained_call: None + chained_call: None, + }; + env::commit(&output); +} + +pub fn write_nssa_outputs_with_chained_call( + pre_states: Vec, + post_states: Vec, + chained_call: Option, +) { + let output = ProgramOutput { + pre_states, + post_states, + chained_call, }; env::commit(&output); } diff --git a/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs index 345198a..d8ed15d 100644 --- a/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -30,9 +30,9 @@ fn main() { chained_call, } = program_output; - // TODO: implement tail calls for privacy preserving transactions + // TODO: implement chained calls for privacy preserving transactions if chained_call.is_some() { - panic!("Privacy preserving transactions do not support yet tail calls.") + panic!("Privacy preserving transactions do not support yet chained calls.") } // Check that there are no repeated account ids diff --git a/nssa/src/program.rs b/nssa/src/program.rs index 877b01b..11eb413 100644 --- a/nssa/src/program.rs +++ b/nssa/src/program.rs @@ -1,6 +1,6 @@ use crate::program_methods::{AUTHENTICATED_TRANSFER_ELF, PINATA_ELF, TOKEN_ELF}; use nssa_core::{ - account::{Account, AccountWithMetadata}, + account::AccountWithMetadata, program::{InstructionData, ProgramId, ProgramOutput}, }; @@ -107,11 +107,11 @@ impl Program { #[cfg(test)] mod tests { - use nssa_core::account::{Account, AccountId, AccountWithMetadata}; - use program_methods::{ + use crate::program_methods::{ AUTHENTICATED_TRANSFER_ELF, AUTHENTICATED_TRANSFER_ID, PINATA_ELF, PINATA_ID, TOKEN_ELF, TOKEN_ID, }; + use nssa_core::account::{Account, AccountId, AccountWithMetadata}; use crate::program::Program; @@ -195,6 +195,15 @@ mod tests { elf: BURNER_ELF.to_vec(), } } + + pub fn chain_caller() -> Self { + use test_program_methods::{CHAIN_CALLER_ELF, CHAIN_CALLER_ID}; + + Program { + id: CHAIN_CALLER_ID, + elf: CHAIN_CALLER_ELF.to_vec(), + } + } } #[test] diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index 7079074..70d22a3 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -114,7 +114,7 @@ impl PublicTransaction { account_indices: (0..pre_states.len()).collect(), }; - for _ in 0..MAX_NUMBER_CHAINED_CALLS { + for _i in 0..MAX_NUMBER_CHAINED_CALLS { // Check the `program_id` corresponds to a deployed program let Some(program) = state.programs().get(&chained_call.program_id) else { return Err(NssaError::InvalidInput("Unknown program".into())); @@ -139,7 +139,7 @@ impl PublicTransaction { if validate_execution( &program_output.pre_states, &program_output.post_states, - message.program_id, + chained_call.program_id, ) { for (pre, post) in program_output .pre_states diff --git a/nssa/src/state.rs b/nssa/src/state.rs index 83183f5..5a0b87b 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -263,6 +263,7 @@ pub mod tests { Commitment, Nullifier, NullifierPublicKey, NullifierSecretKey, SharedSecretKey, account::{Account, AccountId, AccountWithMetadata, Nonce}, encryption::{EphemeralPublicKey, IncomingViewingPublicKey, Scalar}, + program::ProgramId, }; fn transfer_transaction( @@ -475,6 +476,7 @@ pub mod tests { self.insert_program(Program::data_changer()); self.insert_program(Program::minter()); self.insert_program(Program::burner()); + self.insert_program(Program::chain_caller()); self } @@ -2045,4 +2047,38 @@ pub mod tests { assert!(matches!(result, Err(NssaError::CircuitProvingError(_)))); } + + #[test] + fn test_chained_call() { + let program = Program::chain_caller(); + let key = PrivateKey::try_new([1; 32]).unwrap(); + let address = Address::from(&PublicKey::new_from_private_key(&key)); + let initial_balance = 100; + let initial_data = [(address, initial_balance)]; + let mut state = + V02State::new_with_genesis_accounts(&initial_data, &[]).with_test_programs(); + let from = address; + let from_key = key; + let to = Address::new([2; 32]); + let amount: u128 = 37; + let instruction: (u128, ProgramId) = + (amount, Program::authenticated_transfer_program().id()); + + let message = public_transaction::Message::try_new( + program.id(), + vec![to, from], + vec![0], + instruction, + ) + .unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&from_key]); + let tx = PublicTransaction::new(message, witness_set); + + state.transition_from_public_transaction(&tx).unwrap(); + + let from_post = state.get_account_by_address(&from); + let to_post = state.get_account_by_address(&to); + assert_eq!(from_post.balance, initial_balance - amount); + assert_eq!(to_post.balance, amount); + } } diff --git a/nssa/test_program_methods/guest/Cargo.lock b/nssa/test_program_methods/guest/Cargo.lock index 8cb2bec..d7e5b67 100644 --- a/nssa/test_program_methods/guest/Cargo.lock +++ b/nssa/test_program_methods/guest/Cargo.lock @@ -1824,6 +1824,8 @@ name = "programs" version = "0.1.0" dependencies = [ "nssa-core", + "risc0-zkvm", + "serde", ] [[package]] diff --git a/nssa/test_program_methods/guest/Cargo.toml b/nssa/test_program_methods/guest/Cargo.toml index 2289292..9e5f543 100644 --- a/nssa/test_program_methods/guest/Cargo.toml +++ b/nssa/test_program_methods/guest/Cargo.toml @@ -6,4 +6,6 @@ edition = "2024" [workspace] [dependencies] +risc0-zkvm = { version = "3.0.3", features = ['std'] } nssa-core = { path = "../../core" } +serde = { version = "1.0.219", default-features = false } diff --git a/nssa/test_program_methods/guest/src/bin/chain_caller.rs b/nssa/test_program_methods/guest/src/bin/chain_caller.rs new file mode 100644 index 0000000..321d032 --- /dev/null +++ b/nssa/test_program_methods/guest/src/bin/chain_caller.rs @@ -0,0 +1,32 @@ +use nssa_core::program::{ + ChainedCall, ProgramId, ProgramInput, read_nssa_inputs, write_nssa_outputs_with_chained_call, +}; +use risc0_zkvm::serde::to_vec; + +type Instruction = (u128, ProgramId); + +fn main() { + let ProgramInput { + pre_states, + instruction: (balance, program_id), + } = read_nssa_inputs::(); + + let [sender_pre, receiver_pre] = match pre_states.try_into() { + Ok(array) => array, + Err(_) => return, + }; + + let instruction_data = to_vec(&balance).unwrap(); + + let chained_call = Some(ChainedCall { + program_id, + instruction_data, + account_indices: vec![1, 0], + }); + + write_nssa_outputs_with_chained_call( + vec![sender_pre.clone(), receiver_pre.clone()], + vec![sender_pre.account, receiver_pre.account], + chained_call, + ); +}