diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index 6582984..3ecee30 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -102,9 +102,14 @@ pub fn validate_execution( { return false; } + + // 6. If a post state has default program owner, the pre state must have been a default account + if post.program_owner == DEFAULT_PROGRAM_ID && pre.account != Account::default() { + return false; + } } - // 6. Total balance is preserved + // 7. Total balance is preserved let total_balance_pre_states: u128 = pre_states.iter().map(|pre| pre.account.balance).sum(); let total_balance_post_states: u128 = post_states.iter().map(|post| post.balance).sum(); if total_balance_pre_states != total_balance_post_states { diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index 70d22a3..900a476 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet}; use nssa_core::{ account::{Account, AccountWithMetadata}, address::Address, - program::{ChainedCall, validate_execution}, + program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution}, }; use sha2::{Digest, digest::FixedOutput}; @@ -131,27 +131,34 @@ impl PublicTransaction { }) .collect::, NssaError>>()?; - let program_output = + let mut program_output = program.execute(&pre_states_chained_call, &chained_call.instruction_data)?; // Verify execution corresponds to a well-behaved program. // See the # Programs section for the definition of the `validate_execution` method. - if validate_execution( + if !validate_execution( &program_output.pre_states, &program_output.post_states, chained_call.program_id, ) { - for (pre, post) in program_output - .pre_states - .iter() - .zip(program_output.post_states) - { - state_diff.insert(pre.account_id, post); - } - } else { return Err(NssaError::InvalidProgramBehavior); } + for post in program_output.post_states.iter_mut() { + // The invoked program claims the accounts with default program id. + if post.program_owner == DEFAULT_PROGRAM_ID { + post.program_owner = chained_call.program_id; + } + } + + for (pre, post) in program_output + .pre_states + .iter() + .zip(program_output.post_states) + { + state_diff.insert(pre.account_id, post); + } + if let Some(next_chained_call) = program_output.chained_call { chained_call = next_chained_call; } else { diff --git a/nssa/src/state.rs b/nssa/src/state.rs index 5a0b87b..2119929 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -8,7 +8,7 @@ use nssa_core::{ Commitment, CommitmentSetDigest, DUMMY_COMMITMENT, MembershipProof, Nullifier, account::Account, address::Address, - program::{DEFAULT_PROGRAM_ID, ProgramId}, + program::ProgramId, }; use std::collections::{HashMap, HashSet}; @@ -114,10 +114,6 @@ impl V02State { let current_account = self.get_account_by_address_mut(address); *current_account = post; - // The invoked program claims the accounts with default program id. - if current_account.program_owner == DEFAULT_PROGRAM_ID { - current_account.program_owner = tx.message().program_id; - } } for address in tx.signer_addresses() { @@ -437,7 +433,7 @@ pub mod tests { } #[test] - fn transition_from_chained_authenticated_transfer_program_invocations() { + fn transition_from_sequence_of_authenticated_transfer_program_invocations() { let key1 = PrivateKey::try_new([8; 32]).unwrap(); let address1 = Address::from(&PublicKey::new_from_private_key(&key1)); let key2 = PrivateKey::try_new([2; 32]).unwrap(); @@ -2048,6 +2044,42 @@ pub mod tests { assert!(matches!(result, Err(NssaError::CircuitProvingError(_)))); } + #[test] + fn test_claiming_mechanism() { + let program = Program::authenticated_transfer_program(); + let key = PrivateKey::try_new([1; 32]).unwrap(); + let address = Address::from(&PublicKey::new_from_private_key(&key)); + let initial_balance = 100; + let initial_data = [(address, initial_balance)]; + let mut state = + V02State::new_with_genesis_accounts(&initial_data, &[]).with_test_programs(); + let from = address; + let from_key = key; + let to = Address::new([2; 32]); + let amount: u128 = 37; + + // Check the recipient is an uninitialized account + assert_eq!(state.get_account_by_address(&to), Account::default()); + + let expected_recipient_post = Account { + program_owner: program.id(), + balance: amount, + ..Account::default() + }; + + let message = + public_transaction::Message::try_new(program.id(), vec![from, to], vec![0], amount) + .unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&from_key]); + let tx = PublicTransaction::new(message, witness_set); + + state.transition_from_public_transaction(&tx).unwrap(); + + let recipient_post = state.get_account_by_address(&to); + + assert_eq!(recipient_post, expected_recipient_post); + } + #[test] fn test_chained_call() { let program = Program::chain_caller(); @@ -2064,9 +2096,15 @@ pub mod tests { let instruction: (u128, ProgramId) = (amount, Program::authenticated_transfer_program().id()); + let expected_to_post = Account { + program_owner: Program::authenticated_transfer_program().id(), + balance: amount, + ..Account::default() + }; + let message = public_transaction::Message::try_new( program.id(), - vec![to, from], + vec![to, from], //The chain_caller program permutes the account order in the chain call vec![0], instruction, ) @@ -2079,6 +2117,6 @@ pub mod tests { let from_post = state.get_account_by_address(&from); let to_post = state.get_account_by_address(&to); assert_eq!(from_post.balance, initial_balance - amount); - assert_eq!(to_post.balance, amount); + assert_eq!(to_post, expected_to_post); } } diff --git a/nssa/test_program_methods/guest/src/bin/chain_caller.rs b/nssa/test_program_methods/guest/src/bin/chain_caller.rs index 321d032..dfd77b1 100644 --- a/nssa/test_program_methods/guest/src/bin/chain_caller.rs +++ b/nssa/test_program_methods/guest/src/bin/chain_caller.rs @@ -5,6 +5,8 @@ use risc0_zkvm::serde::to_vec; type Instruction = (u128, ProgramId); +/// A program that calls another program. +/// It permutes the order of the input accounts on the subsequent call fn main() { let ProgramInput { pre_states, @@ -21,7 +23,7 @@ fn main() { let chained_call = Some(ChainedCall { program_id, instruction_data, - account_indices: vec![1, 0], + account_indices: vec![1, 0], // <- Account order permutation here }); write_nssa_outputs_with_chained_call(