diff --git a/Cargo.lock b/Cargo.lock index f225074a..9e1d157c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7846,6 +7846,7 @@ dependencies = [ "clock_core", "nssa_core", "risc0-zkvm", + "serde", ] [[package]] diff --git a/artifacts/program_methods/amm.bin b/artifacts/program_methods/amm.bin index 266a26c6..775ec45f 100644 Binary files a/artifacts/program_methods/amm.bin and b/artifacts/program_methods/amm.bin differ diff --git a/artifacts/program_methods/associated_token_account.bin b/artifacts/program_methods/associated_token_account.bin index 9b2debdb..917e0dc5 100644 Binary files a/artifacts/program_methods/associated_token_account.bin and b/artifacts/program_methods/associated_token_account.bin differ diff --git a/artifacts/program_methods/authenticated_transfer.bin b/artifacts/program_methods/authenticated_transfer.bin index b2714e31..cdce17b9 100644 Binary files a/artifacts/program_methods/authenticated_transfer.bin and b/artifacts/program_methods/authenticated_transfer.bin differ diff --git a/artifacts/program_methods/clock.bin b/artifacts/program_methods/clock.bin index ad40b99b..37a4d30f 100644 Binary files a/artifacts/program_methods/clock.bin and b/artifacts/program_methods/clock.bin differ diff --git a/artifacts/program_methods/pinata.bin b/artifacts/program_methods/pinata.bin index ac0921ec..e18d5c2c 100644 Binary files a/artifacts/program_methods/pinata.bin and b/artifacts/program_methods/pinata.bin differ diff --git a/artifacts/program_methods/pinata_token.bin b/artifacts/program_methods/pinata_token.bin index abf68653..f2115a68 100644 Binary files a/artifacts/program_methods/pinata_token.bin and b/artifacts/program_methods/pinata_token.bin differ diff --git a/artifacts/program_methods/privacy_preserving_circuit.bin b/artifacts/program_methods/privacy_preserving_circuit.bin index ca474164..21cf0ddb 100644 Binary files a/artifacts/program_methods/privacy_preserving_circuit.bin and b/artifacts/program_methods/privacy_preserving_circuit.bin differ diff --git a/artifacts/program_methods/token.bin b/artifacts/program_methods/token.bin index 6c83eb91..ebb374f3 100644 Binary files a/artifacts/program_methods/token.bin and b/artifacts/program_methods/token.bin differ diff --git a/artifacts/test_program_methods/burner.bin b/artifacts/test_program_methods/burner.bin index 9f519eef..e2fea8bd 100644 Binary files a/artifacts/test_program_methods/burner.bin and b/artifacts/test_program_methods/burner.bin differ diff --git a/artifacts/test_program_methods/chain_caller.bin b/artifacts/test_program_methods/chain_caller.bin index e3eaca78..d6670787 100644 Binary files a/artifacts/test_program_methods/chain_caller.bin and b/artifacts/test_program_methods/chain_caller.bin differ diff --git a/artifacts/test_program_methods/changer_claimer.bin b/artifacts/test_program_methods/changer_claimer.bin index b2b8895b..47c4200e 100644 Binary files a/artifacts/test_program_methods/changer_claimer.bin and b/artifacts/test_program_methods/changer_claimer.bin differ diff --git a/artifacts/test_program_methods/claimer.bin b/artifacts/test_program_methods/claimer.bin index b97cf30e..8b8bc140 100644 Binary files a/artifacts/test_program_methods/claimer.bin and b/artifacts/test_program_methods/claimer.bin differ diff --git a/artifacts/test_program_methods/clock_chain_caller.bin b/artifacts/test_program_methods/clock_chain_caller.bin index 238225e4..2faa9b69 100644 Binary files a/artifacts/test_program_methods/clock_chain_caller.bin and b/artifacts/test_program_methods/clock_chain_caller.bin differ diff --git a/artifacts/test_program_methods/data_changer.bin b/artifacts/test_program_methods/data_changer.bin index 02b0d1d7..2ade0385 100644 Binary files a/artifacts/test_program_methods/data_changer.bin and b/artifacts/test_program_methods/data_changer.bin differ diff --git a/artifacts/test_program_methods/extra_output.bin b/artifacts/test_program_methods/extra_output.bin index e29d9557..d0095d2b 100644 Binary files a/artifacts/test_program_methods/extra_output.bin and b/artifacts/test_program_methods/extra_output.bin differ diff --git a/artifacts/test_program_methods/flash_swap_callback.bin b/artifacts/test_program_methods/flash_swap_callback.bin new file mode 100644 index 00000000..f259c5b3 Binary files /dev/null and b/artifacts/test_program_methods/flash_swap_callback.bin differ diff --git a/artifacts/test_program_methods/flash_swap_initiator.bin b/artifacts/test_program_methods/flash_swap_initiator.bin new file mode 100644 index 00000000..f1b67504 Binary files /dev/null and b/artifacts/test_program_methods/flash_swap_initiator.bin differ diff --git a/artifacts/test_program_methods/malicious_authorization_changer.bin b/artifacts/test_program_methods/malicious_authorization_changer.bin index 60d0ee2e..75df8bec 100644 Binary files a/artifacts/test_program_methods/malicious_authorization_changer.bin and b/artifacts/test_program_methods/malicious_authorization_changer.bin differ diff --git a/artifacts/test_program_methods/malicious_caller_program_id.bin b/artifacts/test_program_methods/malicious_caller_program_id.bin new file mode 100644 index 00000000..9907ba58 Binary files /dev/null and b/artifacts/test_program_methods/malicious_caller_program_id.bin differ diff --git a/artifacts/test_program_methods/malicious_self_program_id.bin b/artifacts/test_program_methods/malicious_self_program_id.bin new file mode 100644 index 00000000..b530a0b3 Binary files /dev/null and b/artifacts/test_program_methods/malicious_self_program_id.bin differ diff --git a/artifacts/test_program_methods/minter.bin b/artifacts/test_program_methods/minter.bin index 9f0f9731..392aa2fa 100644 Binary files a/artifacts/test_program_methods/minter.bin and b/artifacts/test_program_methods/minter.bin differ diff --git a/artifacts/test_program_methods/missing_output.bin b/artifacts/test_program_methods/missing_output.bin index d8312323..92998b57 100644 Binary files a/artifacts/test_program_methods/missing_output.bin and b/artifacts/test_program_methods/missing_output.bin differ diff --git a/artifacts/test_program_methods/modified_transfer.bin b/artifacts/test_program_methods/modified_transfer.bin index 2fff50cb..65475b18 100644 Binary files a/artifacts/test_program_methods/modified_transfer.bin and b/artifacts/test_program_methods/modified_transfer.bin differ diff --git a/artifacts/test_program_methods/nonce_changer.bin b/artifacts/test_program_methods/nonce_changer.bin index f9a50a54..809ed4ec 100644 Binary files a/artifacts/test_program_methods/nonce_changer.bin and b/artifacts/test_program_methods/nonce_changer.bin differ diff --git a/artifacts/test_program_methods/noop.bin b/artifacts/test_program_methods/noop.bin index 39e1161b..9c2fa8bc 100644 Binary files a/artifacts/test_program_methods/noop.bin and b/artifacts/test_program_methods/noop.bin differ diff --git a/artifacts/test_program_methods/pinata_cooldown.bin b/artifacts/test_program_methods/pinata_cooldown.bin index c735e157..36e60f9c 100644 Binary files a/artifacts/test_program_methods/pinata_cooldown.bin and b/artifacts/test_program_methods/pinata_cooldown.bin differ diff --git a/artifacts/test_program_methods/program_owner_changer.bin b/artifacts/test_program_methods/program_owner_changer.bin index 45ada93d..4dbb34b8 100644 Binary files a/artifacts/test_program_methods/program_owner_changer.bin and b/artifacts/test_program_methods/program_owner_changer.bin differ diff --git a/artifacts/test_program_methods/simple_balance_transfer.bin b/artifacts/test_program_methods/simple_balance_transfer.bin index 06e575a6..df9bee1d 100644 Binary files a/artifacts/test_program_methods/simple_balance_transfer.bin and b/artifacts/test_program_methods/simple_balance_transfer.bin differ diff --git a/artifacts/test_program_methods/time_locked_transfer.bin b/artifacts/test_program_methods/time_locked_transfer.bin index b6267617..8b3da3ea 100644 Binary files a/artifacts/test_program_methods/time_locked_transfer.bin and b/artifacts/test_program_methods/time_locked_transfer.bin differ diff --git a/artifacts/test_program_methods/validity_window.bin b/artifacts/test_program_methods/validity_window.bin index 6fd4c787..009bb965 100644 Binary files a/artifacts/test_program_methods/validity_window.bin and b/artifacts/test_program_methods/validity_window.bin differ diff --git a/artifacts/test_program_methods/validity_window_chain_caller.bin b/artifacts/test_program_methods/validity_window_chain_caller.bin index 875131f1..cf9e8af5 100644 Binary files a/artifacts/test_program_methods/validity_window_chain_caller.bin and b/artifacts/test_program_methods/validity_window_chain_caller.bin differ diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world.rs b/examples/program_deployment/methods/guest/src/bin/hello_world.rs index ea2edd95..3e91db0e 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world.rs @@ -20,6 +20,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: greeting, }, @@ -53,6 +54,7 @@ fn main() { // called to commit the output. ProgramOutput::new( self_program_id, + caller_program_id, instruction_data, vec![pre_state], vec![post_state], diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs b/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs index 3f369fa7..70dfa2ae 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs @@ -20,6 +20,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: greeting, }, @@ -60,6 +61,7 @@ fn main() { // called to commit the output. ProgramOutput::new( self_program_id, + caller_program_id, instruction_data, vec![pre_state], vec![post_state], diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs b/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs index 57a2190c..4289349b 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs @@ -67,6 +67,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (function_id, data), }, @@ -86,5 +87,12 @@ fn main() { // WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be // called to commit the output. - ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write(); + ProgramOutput::new( + self_program_id, + caller_program_id, + instruction_words, + pre_states, + post_states, + ) + .write(); } diff --git a/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs b/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs index 22098b7a..716e5c29 100644 --- a/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs +++ b/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs @@ -28,6 +28,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (), }, @@ -58,6 +59,7 @@ fn main() { // called to commit the output. ProgramOutput::new( self_program_id, + caller_program_id, instruction_data, vec![pre_state], vec![post_state], diff --git a/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs b/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs index 2ae65ec7..5ec9aaab 100644 --- a/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs +++ b/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs @@ -34,6 +34,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (), }, @@ -71,6 +72,7 @@ fn main() { // called to commit the output. ProgramOutput::new( self_program_id, + caller_program_id, instruction_data, vec![pre_state], vec![post_state], diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index 057c8238..a08fb2b4 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -17,6 +17,7 @@ pub type ProgramId = [u32; 8]; pub type InstructionData = Vec; pub struct ProgramInput { pub self_program_id: ProgramId, + pub caller_program_id: Option, pub pre_states: Vec, pub instruction: T, } @@ -284,6 +285,9 @@ pub struct InvalidWindow; pub struct ProgramOutput { /// The program ID of the program that produced this output. pub self_program_id: ProgramId, + /// The program ID of the caller that invoked this program via a chained call, + /// or `None` if this is a top-level call. + pub caller_program_id: Option, /// The instruction data the program received to produce this output. pub instruction_data: InstructionData, /// The account pre states the program received to produce this output. @@ -301,12 +305,14 @@ pub struct ProgramOutput { impl ProgramOutput { pub const fn new( self_program_id: ProgramId, + caller_program_id: Option, instruction_data: InstructionData, pre_states: Vec, post_states: Vec, ) -> Self { Self { self_program_id, + caller_program_id, instruction_data, pre_states, post_states, @@ -421,12 +427,14 @@ pub fn compute_authorized_pdas( #[must_use] pub fn read_nssa_inputs() -> (ProgramInput, InstructionData) { let self_program_id: ProgramId = env::read(); + let caller_program_id: Option = env::read(); let pre_states: Vec = env::read(); let instruction_words: InstructionData = env::read(); let instruction = T::deserialize(&mut Deserializer::new(instruction_words.as_ref())).unwrap(); ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction, }, @@ -627,7 +635,7 @@ mod tests { #[test] fn program_output_try_with_block_validity_window_range() { - let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![]) .try_with_block_validity_window(10_u64..100) .unwrap(); assert_eq!(output.block_validity_window.start(), Some(10)); @@ -636,7 +644,7 @@ mod tests { #[test] fn program_output_with_block_validity_window_range_from() { - let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![]) .with_block_validity_window(10_u64..); assert_eq!(output.block_validity_window.start(), Some(10)); assert_eq!(output.block_validity_window.end(), None); @@ -644,7 +652,7 @@ mod tests { #[test] fn program_output_with_block_validity_window_range_to() { - let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![]) .with_block_validity_window(..100_u64); assert_eq!(output.block_validity_window.start(), None); assert_eq!(output.block_validity_window.end(), Some(100)); @@ -652,7 +660,7 @@ mod tests { #[test] fn program_output_try_with_block_validity_window_empty_range_fails() { - let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![]) .try_with_block_validity_window(5_u64..5); assert!(result.is_err()); } diff --git a/nssa/src/privacy_preserving_transaction/circuit.rs b/nssa/src/privacy_preserving_transaction/circuit.rs index 48c59ce7..6c174450 100644 --- a/nssa/src/privacy_preserving_transaction/circuit.rs +++ b/nssa/src/privacy_preserving_transaction/circuit.rs @@ -87,15 +87,16 @@ pub fn execute_and_prove( pda_seeds: vec![], }; - let mut chained_calls = VecDeque::from_iter([(initial_call, initial_program)]); + let mut chained_calls = VecDeque::from_iter([(initial_call, initial_program, None)]); let mut chain_calls_counter = 0; - while let Some((chained_call, program)) = chained_calls.pop_front() { + while let Some((chained_call, program, caller_program_id)) = chained_calls.pop_front() { if chain_calls_counter >= MAX_NUMBER_CHAINED_CALLS { return Err(NssaError::MaxChainedCallsDepthExceeded); } let inner_receipt = execute_and_prove_program( program, + caller_program_id, &chained_call.pre_states, &chained_call.instruction_data, )?; @@ -115,7 +116,7 @@ pub fn execute_and_prove( let next_program = dependencies .get(&new_call.program_id) .ok_or(NssaError::InvalidProgramBehavior)?; - chained_calls.push_front((new_call, next_program)); + chained_calls.push_front((new_call, next_program, Some(chained_call.program_id))); } chain_calls_counter = chain_calls_counter @@ -153,12 +154,19 @@ pub fn execute_and_prove( fn execute_and_prove_program( program: &Program, + caller_program_id: Option, pre_states: &[AccountWithMetadata], instruction_data: &InstructionData, ) -> Result { // Write inputs to the program let mut env_builder = ExecutorEnv::builder(); - Program::write_inputs(program.id(), pre_states, instruction_data, &mut env_builder)?; + Program::write_inputs( + program.id(), + caller_program_id, + pre_states, + instruction_data, + &mut env_builder, + )?; let env = env_builder.build().unwrap(); // Prove the program diff --git a/nssa/src/program.rs b/nssa/src/program.rs index ed0e90ad..3dd37bce 100644 --- a/nssa/src/program.rs +++ b/nssa/src/program.rs @@ -53,13 +53,20 @@ impl Program { pub(crate) fn execute( &self, + caller_program_id: Option, pre_states: &[AccountWithMetadata], instruction_data: &InstructionData, ) -> Result { // Write inputs to the program let mut env_builder = ExecutorEnv::builder(); env_builder.session_limit(Some(MAX_NUM_CYCLES_PUBLIC_EXECUTION)); - Self::write_inputs(self.id, pre_states, instruction_data, &mut env_builder)?; + Self::write_inputs( + self.id, + caller_program_id, + pre_states, + instruction_data, + &mut env_builder, + )?; let env = env_builder.build().unwrap(); // Execute the program (without proving) @@ -80,6 +87,7 @@ impl Program { /// Writes inputs to `env_builder` in the order expected by the programs. pub(crate) fn write_inputs( program_id: ProgramId, + caller_program_id: Option, pre_states: &[AccountWithMetadata], instruction_data: &[u32], env_builder: &mut ExecutorEnvBuilder, @@ -87,6 +95,9 @@ impl Program { env_builder .write(&program_id) .map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?; + env_builder + .write(&caller_program_id) + .map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?; let pre_states = pre_states.to_vec(); env_builder .write(&pre_states) @@ -320,6 +331,34 @@ mod tests { Self::new(VALIDITY_WINDOW_CHAIN_CALLER_ELF.to_vec()).unwrap() } + #[must_use] + pub fn flash_swap_initiator() -> Self { + use test_program_methods::FLASH_SWAP_INITIATOR_ELF; + Self::new(FLASH_SWAP_INITIATOR_ELF.to_vec()) + .expect("flash_swap_initiator must be a valid Risc0 program") + } + + #[must_use] + pub fn flash_swap_callback() -> Self { + use test_program_methods::FLASH_SWAP_CALLBACK_ELF; + Self::new(FLASH_SWAP_CALLBACK_ELF.to_vec()) + .expect("flash_swap_callback must be a valid Risc0 program") + } + + #[must_use] + pub fn malicious_self_program_id() -> Self { + use test_program_methods::MALICIOUS_SELF_PROGRAM_ID_ELF; + Self::new(MALICIOUS_SELF_PROGRAM_ID_ELF.to_vec()) + .expect("malicious_self_program_id must be a valid Risc0 program") + } + + #[must_use] + pub fn malicious_caller_program_id() -> Self { + use test_program_methods::MALICIOUS_CALLER_PROGRAM_ID_ELF; + Self::new(MALICIOUS_CALLER_PROGRAM_ID_ELF.to_vec()) + .expect("malicious_caller_program_id must be a valid Risc0 program") + } + #[must_use] pub fn time_locked_transfer() -> Self { use test_program_methods::TIME_LOCKED_TRANSFER_ELF; @@ -358,7 +397,7 @@ mod tests { ..Account::default() }; let program_output = program - .execute(&[sender, recipient], &instruction_data) + .execute(None, &[sender, recipient], &instruction_data) .unwrap(); let [sender_post, recipient_post] = program_output.post_states.try_into().unwrap(); diff --git a/nssa/src/state.rs b/nssa/src/state.rs index 3024fe60..7753e1a3 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -400,6 +400,10 @@ pub mod tests { self.insert_program(Program::claimer()); self.insert_program(Program::changer_claimer()); self.insert_program(Program::validity_window()); + self.insert_program(Program::flash_swap_initiator()); + self.insert_program(Program::flash_swap_callback()); + self.insert_program(Program::malicious_self_program_id()); + self.insert_program(Program::malicious_caller_program_id()); self.insert_program(Program::time_locked_transfer()); self.insert_program(Program::pinata_cooldown()); self @@ -478,6 +482,28 @@ pub mod tests { } } + // ── Flash Swap types (mirrors of guest types for host-side serialisation) ── + + #[derive(serde::Serialize, serde::Deserialize)] + struct CallbackInstruction { + return_funds: bool, + token_program_id: ProgramId, + amount: u128, + } + + #[derive(serde::Serialize, serde::Deserialize)] + enum FlashSwapInstruction { + Initiate { + token_program_id: ProgramId, + callback_program_id: ProgramId, + amount_out: u128, + callback_instruction_data: Vec, + }, + InvariantCheck { + min_vault_balance: u128, + }, + } + fn transfer_transaction( from: AccountId, from_key: &PrivateKey, @@ -497,6 +523,23 @@ pub mod tests { PublicTransaction::new(message, witness_set) } + fn build_flash_swap_tx( + initiator: &Program, + vault_id: AccountId, + receiver_id: AccountId, + instruction: FlashSwapInstruction, + ) -> PublicTransaction { + let message = public_transaction::Message::try_new( + initiator.id(), + vec![vault_id, receiver_id], + vec![], // no signers — vault is PDA-authorised + instruction, + ) + .unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[]); + PublicTransaction::new(message, witness_set) + } + #[test] fn new_with_genesis() { let key1 = PrivateKey::try_new([1; 32]).unwrap(); @@ -3877,4 +3920,242 @@ pub mod tests { let state_from_bytes: V03State = borsh::from_slice(&bytes).unwrap(); assert_eq!(state, state_from_bytes); } + + #[test] + fn flash_swap_successful() { + let initiator = Program::flash_swap_initiator(); + let callback = Program::flash_swap_callback(); + let token = Program::authenticated_transfer_program(); + + let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32]))); + let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32]))); + + let initial_balance: u128 = 1000; + let amount_out: u128 = 100; + + let vault_account = Account { + program_owner: token.id(), + balance: initial_balance, + ..Account::default() + }; + let receiver_account = Account { + program_owner: token.id(), + balance: 0, + ..Account::default() + }; + + let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs(); + state.force_insert_account(vault_id, vault_account); + state.force_insert_account(receiver_id, receiver_account); + + // Callback instruction: return funds + let cb_instruction = CallbackInstruction { + return_funds: true, + token_program_id: token.id(), + amount: amount_out, + }; + let cb_data = Program::serialize_instruction(cb_instruction).unwrap(); + + let instruction = FlashSwapInstruction::Initiate { + token_program_id: token.id(), + callback_program_id: callback.id(), + amount_out, + callback_instruction_data: cb_data, + }; + + let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction); + let result = state.transition_from_public_transaction(&tx, 1, 0); + assert!(result.is_ok(), "flash swap should succeed: {result:?}"); + + // Vault balance restored, receiver back to 0 + assert_eq!(state.get_account_by_id(vault_id).balance, initial_balance); + assert_eq!(state.get_account_by_id(receiver_id).balance, 0); + } + + #[test] + fn flash_swap_callback_keeps_funds_rollback() { + let initiator = Program::flash_swap_initiator(); + let callback = Program::flash_swap_callback(); + let token = Program::authenticated_transfer_program(); + + let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32]))); + let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32]))); + + let initial_balance: u128 = 1000; + let amount_out: u128 = 100; + + let vault_account = Account { + program_owner: token.id(), + balance: initial_balance, + ..Account::default() + }; + let receiver_account = Account { + program_owner: token.id(), + balance: 0, + ..Account::default() + }; + + let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs(); + state.force_insert_account(vault_id, vault_account); + state.force_insert_account(receiver_id, receiver_account); + + // Callback instruction: do NOT return funds + let cb_instruction = CallbackInstruction { + return_funds: false, + token_program_id: token.id(), + amount: amount_out, + }; + let cb_data = Program::serialize_instruction(cb_instruction).unwrap(); + + let instruction = FlashSwapInstruction::Initiate { + token_program_id: token.id(), + callback_program_id: callback.id(), + amount_out, + callback_instruction_data: cb_data, + }; + + let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction); + let result = state.transition_from_public_transaction(&tx, 1, 0); + + // Invariant check fails → entire tx rolls back + assert!( + result.is_err(), + "flash swap should fail when callback keeps funds" + ); + + // State unchanged (rollback) + assert_eq!(state.get_account_by_id(vault_id).balance, initial_balance); + assert_eq!(state.get_account_by_id(receiver_id).balance, 0); + } + + #[test] + fn flash_swap_self_call_targets_correct_program() { + // Zero-amount flash swap: the invariant self-call still runs and succeeds + // because vault balance doesn't decrease. + let initiator = Program::flash_swap_initiator(); + let callback = Program::flash_swap_callback(); + let token = Program::authenticated_transfer_program(); + + let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32]))); + let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32]))); + + let initial_balance: u128 = 1000; + + let vault_account = Account { + program_owner: token.id(), + balance: initial_balance, + ..Account::default() + }; + let receiver_account = Account { + program_owner: token.id(), + balance: 0, + ..Account::default() + }; + + let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs(); + state.force_insert_account(vault_id, vault_account); + state.force_insert_account(receiver_id, receiver_account); + + let cb_instruction = CallbackInstruction { + return_funds: true, + token_program_id: token.id(), + amount: 0, + }; + let cb_data = Program::serialize_instruction(cb_instruction).unwrap(); + + let instruction = FlashSwapInstruction::Initiate { + token_program_id: token.id(), + callback_program_id: callback.id(), + amount_out: 0, + callback_instruction_data: cb_data, + }; + + let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction); + let result = state.transition_from_public_transaction(&tx, 1, 0); + assert!( + result.is_ok(), + "zero-amount flash swap should succeed: {result:?}" + ); + } + + #[test] + fn flash_swap_standalone_invariant_check_rejected() { + // Calling InvariantCheck directly (not as a chained self-call) should fail + // because caller_program_id will be None. + let initiator = Program::flash_swap_initiator(); + let token = Program::authenticated_transfer_program(); + + let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32]))); + + let vault_account = Account { + program_owner: token.id(), + balance: 1000, + ..Account::default() + }; + + let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs(); + state.force_insert_account(vault_id, vault_account); + + let instruction = FlashSwapInstruction::InvariantCheck { + min_vault_balance: 1000, + }; + + let message = public_transaction::Message::try_new( + initiator.id(), + vec![vault_id], + vec![], + instruction, + ) + .unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[]); + let tx = PublicTransaction::new(message, witness_set); + + let result = state.transition_from_public_transaction(&tx, 1, 0); + assert!( + result.is_err(), + "standalone InvariantCheck should be rejected (caller_program_id is None)" + ); + } + + #[test] + fn malicious_self_program_id_rejected_in_public_execution() { + let program = Program::malicious_self_program_id(); + let acc_id = AccountId::new([99; 32]); + let account = Account::default(); + + let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs(); + state.force_insert_account(acc_id, account); + + let message = + public_transaction::Message::try_new(program.id(), vec![acc_id], vec![], ()).unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[]); + let tx = PublicTransaction::new(message, witness_set); + + let result = state.transition_from_public_transaction(&tx, 1, 0); + assert!( + result.is_err(), + "program with wrong self_program_id in output should be rejected" + ); + } + + #[test] + fn malicious_caller_program_id_rejected_in_public_execution() { + let program = Program::malicious_caller_program_id(); + let acc_id = AccountId::new([99; 32]); + let account = Account::default(); + + let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs(); + state.force_insert_account(acc_id, account); + + let message = + public_transaction::Message::try_new(program.id(), vec![acc_id], vec![], ()).unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[]); + let tx = PublicTransaction::new(message, witness_set); + + let result = state.transition_from_public_transaction(&tx, 1, 0); + assert!( + result.is_err(), + "program with spoofed caller_program_id in output should be rejected" + ); + } } diff --git a/nssa/src/validated_state_diff.rs b/nssa/src/validated_state_diff.rs index 71f697dd..9614d1b7 100644 --- a/nssa/src/validated_state_diff.rs +++ b/nssa/src/validated_state_diff.rs @@ -118,8 +118,11 @@ impl ValidatedStateDiff { "Program {:?} pre_states: {:?}, instruction_data: {:?}", chained_call.program_id, chained_call.pre_states, chained_call.instruction_data ); - let mut program_output = - program.execute(&chained_call.pre_states, &chained_call.instruction_data)?; + let mut program_output = program.execute( + caller_program_id, + &chained_call.pre_states, + &chained_call.instruction_data, + )?; debug!( "Program {:?} output: {:?}", chained_call.program_id, program_output @@ -159,6 +162,12 @@ impl ValidatedStateDiff { NssaError::InvalidProgramBehavior ); + // Verify that the program output's caller_program_id matches the actual caller. + ensure!( + program_output.caller_program_id == caller_program_id, + NssaError::InvalidProgramBehavior + ); + // Verify execution corresponds to a well-behaved program. // See the # Programs section for the definition of the `validate_execution` method. ensure!( diff --git a/program_methods/guest/src/bin/amm.rs b/program_methods/guest/src/bin/amm.rs index 59c89742..bce76c63 100644 --- a/program_methods/guest/src/bin/amm.rs +++ b/program_methods/guest/src/bin/amm.rs @@ -15,6 +15,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction, }, @@ -155,6 +156,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, pre_states_clone, post_states, diff --git a/program_methods/guest/src/bin/associated_token_account.rs b/program_methods/guest/src/bin/associated_token_account.rs index 42162ba2..9b155d7f 100644 --- a/program_methods/guest/src/bin/associated_token_account.rs +++ b/program_methods/guest/src/bin/associated_token_account.rs @@ -5,6 +5,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction, }, @@ -59,6 +60,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, pre_states_clone, post_states, diff --git a/program_methods/guest/src/bin/authenticated_transfer.rs b/program_methods/guest/src/bin/authenticated_transfer.rs index d7c68e62..32b69c3a 100644 --- a/program_methods/guest/src/bin/authenticated_transfer.rs +++ b/program_methods/guest/src/bin/authenticated_transfer.rs @@ -68,6 +68,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: balance_to_move, }, @@ -85,5 +86,12 @@ fn main() { _ => panic!("invalid params"), }; - ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write(); + ProgramOutput::new( + self_program_id, + caller_program_id, + instruction_words, + pre_states, + post_states, + ) + .write(); } diff --git a/program_methods/guest/src/bin/clock.rs b/program_methods/guest/src/bin/clock.rs index c06b7336..cb49c384 100644 --- a/program_methods/guest/src/bin/clock.rs +++ b/program_methods/guest/src/bin/clock.rs @@ -40,6 +40,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: timestamp, }, @@ -84,6 +85,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre_01, pre_10, pre_50], vec![post_01, post_10, post_50], diff --git a/program_methods/guest/src/bin/pinata.rs b/program_methods/guest/src/bin/pinata.rs index d6f35ae8..dcc76397 100644 --- a/program_methods/guest/src/bin/pinata.rs +++ b/program_methods/guest/src/bin/pinata.rs @@ -47,6 +47,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: solution, }, @@ -81,6 +82,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pinata, winner], vec![ diff --git a/program_methods/guest/src/bin/pinata_token.rs b/program_methods/guest/src/bin/pinata_token.rs index 5c31af45..1f7ad9da 100644 --- a/program_methods/guest/src/bin/pinata_token.rs +++ b/program_methods/guest/src/bin/pinata_token.rs @@ -53,6 +53,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: solution, }, @@ -99,6 +100,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![ pinata_definition, diff --git a/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/program_methods/guest/src/bin/privacy_preserving_circuit.rs index 48d4b3b7..1d091e1c 100644 --- a/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -114,6 +114,15 @@ impl ExecutionState { "Program output self_program_id does not match chained call program_id" ); + // Verify that the program output's caller_program_id matches the actual caller. + // This prevents a malicious user from privately executing an internal function + // by spoofing caller_program_id (e.g. passing caller_program_id = self_program_id + // to bypass access control checks). + assert_eq!( + program_output.caller_program_id, caller_program_id, + "Program output caller_program_id does not match actual caller" + ); + // Check that the program is well behaved. // See the # Programs section for the definition of the `validate_execution` method. let execution_valid = validate_execution( diff --git a/program_methods/guest/src/bin/token.rs b/program_methods/guest/src/bin/token.rs index 2414a289..68205d77 100644 --- a/program_methods/guest/src/bin/token.rs +++ b/program_methods/guest/src/bin/token.rs @@ -13,6 +13,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction, }, @@ -84,6 +85,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, pre_states_clone, post_states, diff --git a/test_program_methods/guest/Cargo.toml b/test_program_methods/guest/Cargo.toml index 9764bd24..46edeb61 100644 --- a/test_program_methods/guest/Cargo.toml +++ b/test_program_methods/guest/Cargo.toml @@ -12,3 +12,4 @@ nssa_core.workspace = true clock_core.workspace = true risc0-zkvm.workspace = true +serde = { workspace = true, default-features = false } diff --git a/test_program_methods/guest/src/bin/burner.rs b/test_program_methods/guest/src/bin/burner.rs index 06ac9b6b..02be2d38 100644 --- a/test_program_methods/guest/src/bin/burner.rs +++ b/test_program_methods/guest/src/bin/burner.rs @@ -6,6 +6,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: balance_to_burn, }, @@ -22,6 +23,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![AccountPostState::new(account_post)], diff --git a/test_program_methods/guest/src/bin/chain_caller.rs b/test_program_methods/guest/src/bin/chain_caller.rs index e8bf9d6f..5c124bed 100644 --- a/test_program_methods/guest/src/bin/chain_caller.rs +++ b/test_program_methods/guest/src/bin/chain_caller.rs @@ -14,6 +14,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (balance, auth_transfer_id, num_chain_calls, pda_seed), }, @@ -57,6 +58,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![sender_pre.clone(), recipient_pre.clone()], vec![ diff --git a/test_program_methods/guest/src/bin/changer_claimer.rs b/test_program_methods/guest/src/bin/changer_claimer.rs index c1bd886c..6d2b51b4 100644 --- a/test_program_methods/guest/src/bin/changer_claimer.rs +++ b/test_program_methods/guest/src/bin/changer_claimer.rs @@ -7,6 +7,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (data_opt, should_claim), }, @@ -36,6 +37,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![post_state], diff --git a/test_program_methods/guest/src/bin/claimer.rs b/test_program_methods/guest/src/bin/claimer.rs index 27b1ae73..a3a7fb19 100644 --- a/test_program_methods/guest/src/bin/claimer.rs +++ b/test_program_methods/guest/src/bin/claimer.rs @@ -6,6 +6,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (), }, @@ -20,6 +21,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![account_post], diff --git a/test_program_methods/guest/src/bin/clock_chain_caller.rs b/test_program_methods/guest/src/bin/clock_chain_caller.rs index 582e228e..cdbe5214 100644 --- a/test_program_methods/guest/src/bin/clock_chain_caller.rs +++ b/test_program_methods/guest/src/bin/clock_chain_caller.rs @@ -15,6 +15,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (clock_program_id, timestamp), }, @@ -33,7 +34,13 @@ fn main() { pda_seeds: vec![], }; - ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states) - .with_chained_calls(vec![chained_call]) - .write(); + ProgramOutput::new( + self_program_id, + caller_program_id, + instruction_words, + pre_states, + post_states, + ) + .with_chained_calls(vec![chained_call]) + .write(); } diff --git a/test_program_methods/guest/src/bin/data_changer.rs b/test_program_methods/guest/src/bin/data_changer.rs index ee7cb235..3969d7f6 100644 --- a/test_program_methods/guest/src/bin/data_changer.rs +++ b/test_program_methods/guest/src/bin/data_changer.rs @@ -7,6 +7,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: data, }, @@ -25,6 +26,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![AccountPostState::new_claimed( diff --git a/test_program_methods/guest/src/bin/extra_output.rs b/test_program_methods/guest/src/bin/extra_output.rs index 924f4d8f..3a5df556 100644 --- a/test_program_methods/guest/src/bin/extra_output.rs +++ b/test_program_methods/guest/src/bin/extra_output.rs @@ -9,6 +9,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, .. }, @@ -23,6 +24,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![ diff --git a/test_program_methods/guest/src/bin/flash_swap_callback.rs b/test_program_methods/guest/src/bin/flash_swap_callback.rs new file mode 100644 index 00000000..251833bb --- /dev/null +++ b/test_program_methods/guest/src/bin/flash_swap_callback.rs @@ -0,0 +1,94 @@ +//! Flash swap callback, the user logic step in the "prep → callback → assert" pattern. +//! +//! # Role +//! +//! This program is called as chained call 2 in the flash swap sequence: +//! 1. Token transfer out (vault → receiver) +//! 2. **This callback** (user logic) +//! 3. Invariant check (assert vault balance restored) +//! +//! In a real flash swap, this would contain the user's arbitrage or other logic. +//! In this test program, it is controlled by `return_funds`: +//! +//! - `return_funds = true`: emits a token transfer (receiver → vault) to return the funds. The +//! invariant check will pass and the transaction will succeed. +//! +//! - `return_funds = false`: emits no transfers. Funds stay with the receiver. The invariant check +//! will fail (vault balance < initial), causing full atomic rollback. This simulates a malicious +//! or buggy callback that does not repay the flash loan. +//! +//! # Note on `caller_program_id` +//! +//! This program does not enforce any access control on `caller_program_id`. +//! It is designed to be called by the flash swap initiator but could in principle be +//! called by any program. In production, a callback would typically verify the caller +//! if it needs to trust the context it is called from. + +use nssa_core::program::{ + AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput, + read_nssa_inputs, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct CallbackInstruction { + /// If true, return the borrowed funds to the vault (happy path). + /// If false, keep the funds (simulates a malicious callback, triggers rollback). + pub return_funds: bool, + pub token_program_id: ProgramId, + pub amount: u128, +} + +fn main() { + let ( + ProgramInput { + self_program_id, + caller_program_id, // not enforced in this callback + pre_states, + instruction, + }, + instruction_words, + ) = read_nssa_inputs::(); + + // pre_states[0] = vault (after transfer out), pre_states[1] = receiver (after transfer out) + let Ok([vault_pre, receiver_pre]) = <[_; 2]>::try_from(pre_states) else { + panic!("Callback requires exactly 2 accounts: vault, receiver"); + }; + + let mut chained_calls = Vec::new(); + + if instruction.return_funds { + // Happy path: return the borrowed funds via a token transfer (receiver → vault). + // The receiver is a PDA of this callback program (seed = [1_u8; 32]). + // Mark the receiver as authorized since it will be PDA-authorized in this chained call. + let mut receiver_authorized = receiver_pre.clone(); + receiver_authorized.is_authorized = true; + let transfer_instruction = risc0_zkvm::serde::to_vec(&instruction.amount) + .expect("transfer instruction serialization"); + + chained_calls.push(ChainedCall { + program_id: instruction.token_program_id, + pre_states: vec![receiver_authorized, vault_pre.clone()], + instruction_data: transfer_instruction, + pda_seeds: vec![PdaSeed::new([1_u8; 32])], + }); + } + // Malicious path (return_funds = false): emit no chained calls. + // The vault balance will not be restored, so the invariant check in the initiator + // will panic, rolling back the entire transaction including the initial transfer out. + + // The callback itself makes no direct state changes, accounts pass through unchanged. + // All mutations go through the token program via chained calls. + ProgramOutput::new( + self_program_id, + caller_program_id, + instruction_words, + vec![vault_pre.clone(), receiver_pre.clone()], + vec![ + AccountPostState::new(vault_pre.account), + AccountPostState::new(receiver_pre.account), + ], + ) + .with_chained_calls(chained_calls) + .write(); +} diff --git a/test_program_methods/guest/src/bin/flash_swap_initiator.rs b/test_program_methods/guest/src/bin/flash_swap_initiator.rs new file mode 100644 index 00000000..27d1f317 --- /dev/null +++ b/test_program_methods/guest/src/bin/flash_swap_initiator.rs @@ -0,0 +1,216 @@ +//! Flash swap initiator, demonstrates the "prep → callback → assert" pattern using +//! generalized multi tail-calls with `self_program_id` and `caller_program_id`. +//! +//! # Pattern +//! +//! A flash swap lets a program optimistically transfer tokens out, run arbitrary user +//! logic (the callback), then assert that invariants hold after the callback. The entire +//! sequence is a single atomic transaction: if any step fails, all state changes roll back. +//! +//! # How it works +//! +//! This program handles two instruction variants: +//! +//! - `Initiate` (external): the top-level entrypoint. Emits 3 chained calls: +//! 1. Token transfer out (vault → receiver) +//! 2. User callback (arbitrary logic, e.g. arbitrage) +//! 3. Self-call to `InvariantCheck` (using `self_program_id` to reference itself) +//! +//! - `InvariantCheck` (internal): enforces that the vault balance was restored after the callback. +//! Uses `caller_program_id == Some(self_program_id)` to prevent standalone calls (this is the +//! visibility enforcement mechanism). +//! +//! # What this demonstrates +//! +//! - `self_program_id`: enables a program to chain back to itself (step 3 above) +//! - `caller_program_id`: enables a program to restrict which callers can invoke an instruction +//! - Computed intermediate states: the initiator computes expected intermediate account states from +//! the `pre_states` and amount, keeping the instruction minimal. +//! - Atomic rollback: if the callback doesn't return funds, the invariant check fails, and all +//! state changes from steps 1 and 2 are rolled back automatically. +//! +//! # Tests +//! +//! See `nssa/src/state.rs` for integration tests: +//! - `flash_swap_successful`: full round-trip, funds returned, state unchanged +//! - `flash_swap_callback_keeps_funds_rollback`: callback keeps funds, full rollback +//! - `flash_swap_self_call_targets_correct_program`: zero-amount self-call isolation test +//! - `flash_swap_standalone_invariant_check_rejected`: `caller_program_id` access control + +use nssa_core::program::{ + AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput, + read_nssa_inputs, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub enum FlashSwapInstruction { + /// External entrypoint: initiate a flash swap. + /// + /// Emits 3 chained calls: + /// 1. Token transfer (vault → receiver, `amount_out`) + /// 2. Callback (user logic, e.g. arbitrage) + /// 3. Self-call `InvariantCheck` (verify vault balance did not decrease) + /// + /// Intermediate account states are computed inside the program from `pre_states` and + /// `amount_out`. + Initiate { + token_program_id: ProgramId, + callback_program_id: ProgramId, + amount_out: u128, + callback_instruction_data: Vec, + }, + /// Internal: verify the vault invariant holds after callback execution. + /// + /// Access control: only callable as a chained call from this program itself. + /// This is enforced by checking `caller_program_id == Some(self_program_id)`. + /// Any attempt to call this instruction as a standalone top-level transaction + /// will be rejected because `caller_program_id` will be `None`. + InvariantCheck { min_vault_balance: u128 }, +} + +fn main() { + let ( + ProgramInput { + self_program_id, + caller_program_id, + pre_states, + instruction, + }, + instruction_words, + ) = read_nssa_inputs::(); + + match instruction { + FlashSwapInstruction::Initiate { + token_program_id, + callback_program_id, + amount_out, + callback_instruction_data, + } => { + let Ok([vault_pre, receiver_pre]) = <[_; 2]>::try_from(pre_states) else { + panic!("Initiate requires exactly 2 accounts: vault, receiver"); + }; + + // Capture initial vault balance, the invariant check will verify it is restored. + let min_vault_balance = vault_pre.account.balance; + + // Compute intermediate account states from pre_states and amount_out. + let mut vault_after_transfer = vault_pre.clone(); + vault_after_transfer.account.balance = vault_pre + .account + .balance + .checked_sub(amount_out) + .expect("vault has insufficient balance for flash swap"); + + let mut receiver_after_transfer = receiver_pre.clone(); + receiver_after_transfer.account.balance = receiver_pre + .account + .balance + .checked_add(amount_out) + .expect("receiver balance overflow"); + + let mut vault_after_callback = vault_after_transfer.clone(); + vault_after_callback.account.balance = vault_after_transfer + .account + .balance + .checked_add(amount_out) + .expect("vault balance overflow after callback"); + + // Chained call 1: Token transfer (vault → receiver). + // The vault is a PDA of this initiator program (seed = [0_u8; 32]), so we provide + // the PDA seed to authorize the token program to debit the vault on our behalf. + // Mark the vault as authorized since it will be PDA-authorized in this chained call. + let mut vault_authorized = vault_pre.clone(); + vault_authorized.is_authorized = true; + let transfer_instruction = + risc0_zkvm::serde::to_vec(&amount_out).expect("transfer instruction serialization"); + let call_1 = ChainedCall { + program_id: token_program_id, + pre_states: vec![vault_authorized, receiver_pre.clone()], + instruction_data: transfer_instruction, + pda_seeds: vec![PdaSeed::new([0_u8; 32])], + }; + + // Chained call 2: User callback. + // Receives the post-transfer states as its pre_states. The callback may run + // arbitrary logic (arbitrage, etc.) and is expected to return funds to the vault. + let call_2 = ChainedCall { + program_id: callback_program_id, + pre_states: vec![vault_after_transfer, receiver_after_transfer], + instruction_data: callback_instruction_data, + pda_seeds: vec![], + }; + + // Chained call 3: Self-call to enforce the invariant. + // Uses `self_program_id` to reference this program, the key feature that enables + // the "prep → callback → assert" pattern without a separate checker program. + // If the callback did not return funds, vault_after_callback.balance < + // min_vault_balance and this call will panic, rolling back the entire + // transaction. + let invariant_instruction = + risc0_zkvm::serde::to_vec(&FlashSwapInstruction::InvariantCheck { + min_vault_balance, + }) + .expect("invariant instruction serialization"); + let call_3 = ChainedCall { + program_id: self_program_id, // self-referential chained call + pre_states: vec![vault_after_callback], + instruction_data: invariant_instruction, + pda_seeds: vec![], + }; + + // The initiator itself makes no direct state changes. + // All mutations happen inside the chained calls (token transfers). + ProgramOutput::new( + self_program_id, + caller_program_id, + instruction_words, + vec![vault_pre.clone(), receiver_pre.clone()], + vec![ + AccountPostState::new(vault_pre.account), + AccountPostState::new(receiver_pre.account), + ], + ) + .with_chained_calls(vec![call_1, call_2, call_3]) + .write(); + } + + FlashSwapInstruction::InvariantCheck { min_vault_balance } => { + // Visibility enforcement: `InvariantCheck` is an internal instruction. + // It must only be called as a chained call from this program itself (via `Initiate`). + // When called as a top-level transaction, `caller_program_id` is `None` → panics. + // When called as a chained call from `Initiate`, `caller_program_id` is + // `Some(self_program_id)` → passes. + assert_eq!( + caller_program_id, + Some(self_program_id), + "InvariantCheck is an internal instruction: must be called by flash_swap_initiator \ + via a chained call", + ); + + let Ok([vault]) = <[_; 1]>::try_from(pre_states) else { + panic!("InvariantCheck requires exactly 1 account: vault"); + }; + + // The core invariant: vault balance must not have decreased. + // If the callback returned funds, this passes. If not, this panics and + // the entire transaction (including the prior token transfer) rolls back. + assert!( + vault.account.balance >= min_vault_balance, + "Flash swap invariant violated: vault balance {} < minimum {}", + vault.account.balance, + min_vault_balance + ); + + // Pass-through: no state changes in the invariant check step. + ProgramOutput::new( + self_program_id, + caller_program_id, + instruction_words, + vec![vault.clone()], + vec![AccountPostState::new(vault.account)], + ) + .write(); + } + } +} diff --git a/test_program_methods/guest/src/bin/malicious_authorization_changer.rs b/test_program_methods/guest/src/bin/malicious_authorization_changer.rs index 1db09a73..f7aba4a0 100644 --- a/test_program_methods/guest/src/bin/malicious_authorization_changer.rs +++ b/test_program_methods/guest/src/bin/malicious_authorization_changer.rs @@ -15,6 +15,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (balance, transfer_program_id), }, @@ -42,6 +43,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![sender.clone(), receiver.clone()], vec![ diff --git a/test_program_methods/guest/src/bin/malicious_caller_program_id.rs b/test_program_methods/guest/src/bin/malicious_caller_program_id.rs new file mode 100644 index 00000000..2326190e --- /dev/null +++ b/test_program_methods/guest/src/bin/malicious_caller_program_id.rs @@ -0,0 +1,34 @@ +use nssa_core::program::{ + AccountPostState, DEFAULT_PROGRAM_ID, ProgramInput, ProgramOutput, read_nssa_inputs, +}; + +type Instruction = (); + +fn main() { + let ( + ProgramInput { + self_program_id, + caller_program_id: _, // ignore the actual caller + pre_states, + instruction: (), + }, + instruction_words, + ) = read_nssa_inputs::(); + + let post_states = pre_states + .iter() + .map(|a| AccountPostState::new(a.account.clone())) + .collect(); + + // Deliberately output wrong caller_program_id. + // A real caller_program_id is None for a top-level call, so we spoof Some(DEFAULT_PROGRAM_ID) + // to simulate a program claiming it was invoked by another program when it was not. + ProgramOutput::new( + self_program_id, + Some(DEFAULT_PROGRAM_ID), // WRONG: should be None for a top-level call + instruction_words, + pre_states, + post_states, + ) + .write(); +} diff --git a/test_program_methods/guest/src/bin/malicious_self_program_id.rs b/test_program_methods/guest/src/bin/malicious_self_program_id.rs new file mode 100644 index 00000000..be447ab9 --- /dev/null +++ b/test_program_methods/guest/src/bin/malicious_self_program_id.rs @@ -0,0 +1,32 @@ +use nssa_core::program::{ + AccountPostState, DEFAULT_PROGRAM_ID, ProgramInput, ProgramOutput, read_nssa_inputs, +}; + +type Instruction = (); + +fn main() { + let ( + ProgramInput { + self_program_id: _, // ignore the correct ID + caller_program_id, + pre_states, + instruction: (), + }, + instruction_words, + ) = read_nssa_inputs::(); + + let post_states = pre_states + .iter() + .map(|a| AccountPostState::new(a.account.clone())) + .collect(); + + // Deliberately output wrong self_program_id + ProgramOutput::new( + DEFAULT_PROGRAM_ID, // WRONG: should be self_program_id + caller_program_id, + instruction_words, + pre_states, + post_states, + ) + .write(); +} diff --git a/test_program_methods/guest/src/bin/minter.rs b/test_program_methods/guest/src/bin/minter.rs index 445df32f..1f31ca05 100644 --- a/test_program_methods/guest/src/bin/minter.rs +++ b/test_program_methods/guest/src/bin/minter.rs @@ -6,6 +6,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, .. }, @@ -25,6 +26,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![AccountPostState::new(account_post)], diff --git a/test_program_methods/guest/src/bin/missing_output.rs b/test_program_methods/guest/src/bin/missing_output.rs index 6b33d95e..d7d2778d 100644 --- a/test_program_methods/guest/src/bin/missing_output.rs +++ b/test_program_methods/guest/src/bin/missing_output.rs @@ -6,6 +6,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, .. }, @@ -20,6 +21,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre1, pre2], vec![AccountPostState::new(account_pre1)], diff --git a/test_program_methods/guest/src/bin/modified_transfer.rs b/test_program_methods/guest/src/bin/modified_transfer.rs index 859f5cc0..2c05921c 100644 --- a/test_program_methods/guest/src/bin/modified_transfer.rs +++ b/test_program_methods/guest/src/bin/modified_transfer.rs @@ -65,6 +65,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: balance_to_move, }, @@ -81,5 +82,12 @@ fn main() { } _ => panic!("invalid params"), }; - ProgramOutput::new(self_program_id, instruction_data, pre_states, post_states).write(); + ProgramOutput::new( + self_program_id, + caller_program_id, + instruction_data, + pre_states, + post_states, + ) + .write(); } diff --git a/test_program_methods/guest/src/bin/nonce_changer.rs b/test_program_methods/guest/src/bin/nonce_changer.rs index 5e1cdbb2..c6e851fe 100644 --- a/test_program_methods/guest/src/bin/nonce_changer.rs +++ b/test_program_methods/guest/src/bin/nonce_changer.rs @@ -6,6 +6,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, .. }, @@ -22,6 +23,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![AccountPostState::new(account_post)], diff --git a/test_program_methods/guest/src/bin/noop.rs b/test_program_methods/guest/src/bin/noop.rs index 71787776..fc92aebe 100644 --- a/test_program_methods/guest/src/bin/noop.rs +++ b/test_program_methods/guest/src/bin/noop.rs @@ -6,6 +6,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, .. }, @@ -16,5 +17,12 @@ fn main() { .iter() .map(|account| AccountPostState::new(account.account.clone())) .collect(); - ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write(); + ProgramOutput::new( + self_program_id, + caller_program_id, + instruction_words, + pre_states, + post_states, + ) + .write(); } diff --git a/test_program_methods/guest/src/bin/pinata_cooldown.rs b/test_program_methods/guest/src/bin/pinata_cooldown.rs index 1ea3465b..9e8bde3b 100644 --- a/test_program_methods/guest/src/bin/pinata_cooldown.rs +++ b/test_program_methods/guest/src/bin/pinata_cooldown.rs @@ -49,6 +49,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (), }, @@ -102,6 +103,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pinata, winner, clock_pre], vec![ diff --git a/test_program_methods/guest/src/bin/program_owner_changer.rs b/test_program_methods/guest/src/bin/program_owner_changer.rs index f1b2cfce..0282b5cc 100644 --- a/test_program_methods/guest/src/bin/program_owner_changer.rs +++ b/test_program_methods/guest/src/bin/program_owner_changer.rs @@ -6,6 +6,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, .. }, @@ -22,6 +23,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![AccountPostState::new(account_post)], diff --git a/test_program_methods/guest/src/bin/simple_balance_transfer.rs b/test_program_methods/guest/src/bin/simple_balance_transfer.rs index 4edd6198..f324b371 100644 --- a/test_program_methods/guest/src/bin/simple_balance_transfer.rs +++ b/test_program_methods/guest/src/bin/simple_balance_transfer.rs @@ -6,6 +6,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: balance, }, @@ -29,6 +30,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![sender_pre, receiver_pre], vec![ diff --git a/test_program_methods/guest/src/bin/time_locked_transfer.rs b/test_program_methods/guest/src/bin/time_locked_transfer.rs index 681d7fcd..25595661 100644 --- a/test_program_methods/guest/src/bin/time_locked_transfer.rs +++ b/test_program_methods/guest/src/bin/time_locked_transfer.rs @@ -19,6 +19,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (amount, deadline), }, @@ -58,6 +59,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![sender_pre, receiver_pre, clock_pre], vec![ diff --git a/test_program_methods/guest/src/bin/validity_window.rs b/test_program_methods/guest/src/bin/validity_window.rs index 67908836..03100e8e 100644 --- a/test_program_methods/guest/src/bin/validity_window.rs +++ b/test_program_methods/guest/src/bin/validity_window.rs @@ -9,6 +9,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (block_validity_window, timestamp_validity_window), }, @@ -23,6 +24,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![AccountPostState::new(post)], diff --git a/test_program_methods/guest/src/bin/validity_window_chain_caller.rs b/test_program_methods/guest/src/bin/validity_window_chain_caller.rs index cbe3c7c1..212418a2 100644 --- a/test_program_methods/guest/src/bin/validity_window_chain_caller.rs +++ b/test_program_methods/guest/src/bin/validity_window_chain_caller.rs @@ -17,6 +17,7 @@ fn main() { let ( ProgramInput { self_program_id, + caller_program_id, pre_states, instruction: (block_validity_window, chained_program_id, chained_block_validity_window), }, @@ -40,6 +41,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pre], vec![AccountPostState::new(post)],