diff --git a/integration_tests/src/data_changer.bin b/integration_tests/src/data_changer.bin index d201f91..c4fbec0 100644 Binary files a/integration_tests/src/data_changer.bin and b/integration_tests/src/data_changer.bin differ diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index 82023f3..3ecee30 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -12,11 +12,20 @@ pub struct ProgramInput { pub instruction: T, } +#[derive(Serialize, Deserialize, Clone)] +#[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))] +pub struct ChainedCall { + pub program_id: ProgramId, + pub instruction_data: InstructionData, + pub account_indices: Vec, +} + #[derive(Serialize, Deserialize, Clone)] #[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))] pub struct ProgramOutput { pub pre_states: Vec, pub post_states: Vec, + pub chained_call: Option, } pub fn read_nssa_inputs() -> ProgramInput { @@ -33,6 +42,20 @@ pub fn write_nssa_outputs(pre_states: Vec, post_states: Vec let output = ProgramOutput { pre_states, post_states, + chained_call: None, + }; + env::commit(&output); +} + +pub fn write_nssa_outputs_with_chained_call( + pre_states: Vec, + post_states: Vec, + chained_call: Option, +) { + let output = ProgramOutput { + pre_states, + post_states, + chained_call, }; env::commit(&output); } @@ -79,9 +102,14 @@ pub fn validate_execution( { return false; } + + // 6. If a post state has default program owner, the pre state must have been a default account + if post.program_owner == DEFAULT_PROGRAM_ID && pre.account != Account::default() { + return false; + } } - // 6. Total balance is preserved + // 7. Total balance is preserved let total_balance_pre_states: u128 = pre_states.iter().map(|pre| pre.account.balance).sum(); let total_balance_post_states: u128 = post_states.iter().map(|post| post.balance).sum(); if total_balance_pre_states != total_balance_post_states { diff --git a/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs index a1aa8c9..d8ed15d 100644 --- a/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -27,8 +27,14 @@ fn main() { let ProgramOutput { pre_states, post_states, + chained_call, } = program_output; + // TODO: implement chained calls for privacy preserving transactions + if chained_call.is_some() { + panic!("Privacy preserving transactions do not support yet chained calls.") + } + // Check that there are no repeated account ids if !validate_uniqueness_of_account_ids(&pre_states) { panic!("Repeated account ids found") diff --git a/nssa/src/program.rs b/nssa/src/program.rs index 7771aaf..11eb413 100644 --- a/nssa/src/program.rs +++ b/nssa/src/program.rs @@ -1,6 +1,6 @@ use crate::program_methods::{AUTHENTICATED_TRANSFER_ELF, PINATA_ELF, TOKEN_ELF}; use nssa_core::{ - account::{Account, AccountWithMetadata}, + account::AccountWithMetadata, program::{InstructionData, ProgramId, ProgramOutput}, }; @@ -48,7 +48,7 @@ impl Program { &self, pre_states: &[AccountWithMetadata], instruction_data: &InstructionData, - ) -> Result, NssaError> { + ) -> Result { // Write inputs to the program let mut env_builder = ExecutorEnv::builder(); env_builder.session_limit(Some(MAX_NUM_CYCLES_PUBLIC_EXECUTION)); @@ -62,12 +62,12 @@ impl Program { .map_err(|e| NssaError::ProgramExecutionFailed(e.to_string()))?; // Get outputs - let ProgramOutput { post_states, .. } = session_info + let program_output = session_info .journal .decode() .map_err(|e| NssaError::ProgramExecutionFailed(e.to_string()))?; - Ok(post_states) + Ok(program_output) } /// Writes inputs to `env_builder` in the order expected by the programs @@ -107,11 +107,11 @@ impl Program { #[cfg(test)] mod tests { - use nssa_core::account::{Account, AccountId, AccountWithMetadata}; - use program_methods::{ + use crate::program_methods::{ AUTHENTICATED_TRANSFER_ELF, AUTHENTICATED_TRANSFER_ID, PINATA_ELF, PINATA_ID, TOKEN_ELF, TOKEN_ID, }; + use nssa_core::account::{Account, AccountId, AccountWithMetadata}; use crate::program::Program; @@ -195,6 +195,15 @@ mod tests { elf: BURNER_ELF.to_vec(), } } + + pub fn chain_caller() -> Self { + use test_program_methods::{CHAIN_CALLER_ELF, CHAIN_CALLER_ID}; + + Program { + id: CHAIN_CALLER_ID, + elf: CHAIN_CALLER_ELF.to_vec(), + } + } } #[test] @@ -221,12 +230,12 @@ mod tests { balance: balance_to_move, ..Account::default() }; - let [sender_post, recipient_post] = program + let program_output = program .execute(&[sender, recipient], &instruction_data) - .unwrap() - .try_into() .unwrap(); + let [sender_post, recipient_post] = program_output.post_states.try_into().unwrap(); + assert_eq!(sender_post, expected_sender_post); assert_eq!(recipient_post, expected_recipient_post); } diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index b0b8f73..d118d0c 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet}; use nssa_core::{ account::{Account, AccountWithMetadata}, address::Address, - program::validate_execution, + program::{DEFAULT_PROGRAM_ID, validate_execution}, }; use sha2::{Digest, digest::FixedOutput}; @@ -18,6 +18,7 @@ pub struct PublicTransaction { message: Message, witness_set: WitnessSet, } +const MAX_NUMBER_CHAINED_CALLS: usize = 10; impl PublicTransaction { pub fn new(message: Message, witness_set: WitnessSet) -> Self { @@ -88,7 +89,7 @@ impl PublicTransaction { } // Build pre_states for execution - let pre_states: Vec<_> = message + let mut input_pre_states: Vec<_> = message .addresses .iter() .map(|address| { @@ -100,21 +101,86 @@ impl PublicTransaction { }) .collect(); - // Check the `program_id` corresponds to a deployed program - let Some(program) = state.programs().get(&message.program_id) else { - return Err(NssaError::InvalidInput("Unknown program".into())); - }; + let mut state_diff: HashMap = HashMap::new(); - // // Execute program - let post_states = program.execute(&pre_states, &message.instruction_data)?; + let mut program_id = message.program_id; + let mut instruction_data = message.instruction_data.clone(); - // Verify execution corresponds to a well-behaved program. - // See the # Programs section for the definition of the `validate_execution` method. - if !validate_execution(&pre_states, &post_states, message.program_id) { - return Err(NssaError::InvalidProgramBehavior); + for _i in 0..MAX_NUMBER_CHAINED_CALLS { + // Check the `program_id` corresponds to a deployed program + let Some(program) = state.programs().get(&program_id) else { + return Err(NssaError::InvalidInput("Unknown program".into())); + }; + + let mut program_output = program.execute(&input_pre_states, &instruction_data)?; + + // This check is equivalent to checking that the program output pre_states coinicide + // with the values in the public state or with any modifications to those values + // during the chain of calls. + if input_pre_states != program_output.pre_states { + return Err(NssaError::InvalidProgramBehavior); + } + + // Verify execution corresponds to a well-behaved program. + // See the # Programs section for the definition of the `validate_execution` method. + if !validate_execution( + &program_output.pre_states, + &program_output.post_states, + program_id, + ) { + return Err(NssaError::InvalidProgramBehavior); + } + + // The invoked program claims the accounts with default program id. + for post in program_output.post_states.iter_mut() { + if post.program_owner == DEFAULT_PROGRAM_ID { + post.program_owner = program_id; + } + } + + // Update the state diff + for (pre, post) in program_output + .pre_states + .iter() + .zip(program_output.post_states.iter()) + { + state_diff.insert(pre.account_id, post.clone()); + } + + if let Some(next_chained_call) = program_output.chained_call { + program_id = next_chained_call.program_id; + instruction_data = next_chained_call.instruction_data; + + // Build post states with metadata for next call + let mut post_states_with_metadata = Vec::new(); + for (pre, post) in program_output + .pre_states + .iter() + .zip(program_output.post_states) + { + let mut post_with_metadata = pre.clone(); + post_with_metadata.account = post.clone(); + post_states_with_metadata.push(post_with_metadata); + } + + input_pre_states = next_chained_call + .account_indices + .iter() + .map(|&i| { + post_states_with_metadata + .get(i) + .ok_or_else(|| { + NssaError::InvalidInput("Invalid account indices".into()) + }) + .cloned() + }) + .collect::, NssaError>>()?; + } else { + break; + }; } - Ok(message.addresses.iter().cloned().zip(post_states).collect()) + Ok(state_diff) } } diff --git a/nssa/src/state.rs b/nssa/src/state.rs index 83183f5..4120824 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -6,9 +6,7 @@ use crate::{ }; use nssa_core::{ Commitment, CommitmentSetDigest, DUMMY_COMMITMENT, MembershipProof, Nullifier, - account::Account, - address::Address, - program::{DEFAULT_PROGRAM_ID, ProgramId}, + account::Account, address::Address, program::ProgramId, }; use std::collections::{HashMap, HashSet}; @@ -114,10 +112,6 @@ impl V02State { let current_account = self.get_account_by_address_mut(address); *current_account = post; - // The invoked program claims the accounts with default program id. - if current_account.program_owner == DEFAULT_PROGRAM_ID { - current_account.program_owner = tx.message().program_id; - } } for address in tx.signer_addresses() { @@ -263,6 +257,7 @@ pub mod tests { Commitment, Nullifier, NullifierPublicKey, NullifierSecretKey, SharedSecretKey, account::{Account, AccountId, AccountWithMetadata, Nonce}, encryption::{EphemeralPublicKey, IncomingViewingPublicKey, Scalar}, + program::ProgramId, }; fn transfer_transaction( @@ -436,7 +431,7 @@ pub mod tests { } #[test] - fn transition_from_chained_authenticated_transfer_program_invocations() { + fn transition_from_sequence_of_authenticated_transfer_program_invocations() { let key1 = PrivateKey::try_new([8; 32]).unwrap(); let address1 = Address::from(&PublicKey::new_from_private_key(&key1)); let key2 = PrivateKey::try_new([2; 32]).unwrap(); @@ -475,6 +470,7 @@ pub mod tests { self.insert_program(Program::data_changer()); self.insert_program(Program::minter()); self.insert_program(Program::burner()); + self.insert_program(Program::chain_caller()); self } @@ -2045,4 +2041,80 @@ pub mod tests { assert!(matches!(result, Err(NssaError::CircuitProvingError(_)))); } + + #[test] + fn test_claiming_mechanism() { + let program = Program::authenticated_transfer_program(); + let key = PrivateKey::try_new([1; 32]).unwrap(); + let address = Address::from(&PublicKey::new_from_private_key(&key)); + let initial_balance = 100; + let initial_data = [(address, initial_balance)]; + let mut state = + V02State::new_with_genesis_accounts(&initial_data, &[]).with_test_programs(); + let from = address; + let from_key = key; + let to = Address::new([2; 32]); + let amount: u128 = 37; + + // Check the recipient is an uninitialized account + assert_eq!(state.get_account_by_address(&to), Account::default()); + + let expected_recipient_post = Account { + program_owner: program.id(), + balance: amount, + ..Account::default() + }; + + let message = + public_transaction::Message::try_new(program.id(), vec![from, to], vec![0], amount) + .unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&from_key]); + let tx = PublicTransaction::new(message, witness_set); + + state.transition_from_public_transaction(&tx).unwrap(); + + let recipient_post = state.get_account_by_address(&to); + + assert_eq!(recipient_post, expected_recipient_post); + } + + #[test] + fn test_chained_call() { + let program = Program::chain_caller(); + let key = PrivateKey::try_new([1; 32]).unwrap(); + let address = Address::from(&PublicKey::new_from_private_key(&key)); + let initial_balance = 100; + let initial_data = [(address, initial_balance)]; + let mut state = + V02State::new_with_genesis_accounts(&initial_data, &[]).with_test_programs(); + let from = address; + let from_key = key; + let to = Address::new([2; 32]); + let amount: u128 = 37; + let instruction: (u128, ProgramId) = + (amount, Program::authenticated_transfer_program().id()); + + let expected_to_post = Account { + program_owner: Program::chain_caller().id(), + balance: amount, + ..Account::default() + }; + + let message = public_transaction::Message::try_new( + program.id(), + vec![to, from], //The chain_caller program permutes the account order in the chain call + vec![0], + instruction, + ) + .unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&from_key]); + let tx = PublicTransaction::new(message, witness_set); + + state.transition_from_public_transaction(&tx).unwrap(); + + let from_post = state.get_account_by_address(&from); + let to_post = state.get_account_by_address(&to); + assert_eq!(from_post.balance, initial_balance - amount); + assert_eq!(to_post, expected_to_post); + } } diff --git a/nssa/test_program_methods/guest/Cargo.lock b/nssa/test_program_methods/guest/Cargo.lock index 8cb2bec..d7e5b67 100644 --- a/nssa/test_program_methods/guest/Cargo.lock +++ b/nssa/test_program_methods/guest/Cargo.lock @@ -1824,6 +1824,8 @@ name = "programs" version = "0.1.0" dependencies = [ "nssa-core", + "risc0-zkvm", + "serde", ] [[package]] diff --git a/nssa/test_program_methods/guest/Cargo.toml b/nssa/test_program_methods/guest/Cargo.toml index 2289292..9e5f543 100644 --- a/nssa/test_program_methods/guest/Cargo.toml +++ b/nssa/test_program_methods/guest/Cargo.toml @@ -6,4 +6,6 @@ edition = "2024" [workspace] [dependencies] +risc0-zkvm = { version = "3.0.3", features = ['std'] } nssa-core = { path = "../../core" } +serde = { version = "1.0.219", default-features = false } diff --git a/nssa/test_program_methods/guest/src/bin/chain_caller.rs b/nssa/test_program_methods/guest/src/bin/chain_caller.rs new file mode 100644 index 0000000..dfd77b1 --- /dev/null +++ b/nssa/test_program_methods/guest/src/bin/chain_caller.rs @@ -0,0 +1,34 @@ +use nssa_core::program::{ + ChainedCall, ProgramId, ProgramInput, read_nssa_inputs, write_nssa_outputs_with_chained_call, +}; +use risc0_zkvm::serde::to_vec; + +type Instruction = (u128, ProgramId); + +/// A program that calls another program. +/// It permutes the order of the input accounts on the subsequent call +fn main() { + let ProgramInput { + pre_states, + instruction: (balance, program_id), + } = read_nssa_inputs::(); + + let [sender_pre, receiver_pre] = match pre_states.try_into() { + Ok(array) => array, + Err(_) => return, + }; + + let instruction_data = to_vec(&balance).unwrap(); + + let chained_call = Some(ChainedCall { + program_id, + instruction_data, + account_indices: vec![1, 0], // <- Account order permutation here + }); + + write_nssa_outputs_with_chained_call( + vec![sender_pre.clone(), receiver_pre.clone()], + vec![sender_pre.account, receiver_pre.account], + chained_call, + ); +}