From 3cfc74695bbb8fac6834ea9e3f98e0dc7d676243 Mon Sep 17 00:00:00 2001 From: Moudy Date: Fri, 3 Apr 2026 22:05:49 +0200 Subject: [PATCH] fix: compute intermediate states inside flash swap programs --- nssa/src/state.rs | 149 ------------------ .../guest/src/bin/flash_swap_callback.rs | 13 +- .../guest/src/bin/flash_swap_initiator.rs | 55 ++++--- 3 files changed, 35 insertions(+), 182 deletions(-) diff --git a/nssa/src/state.rs b/nssa/src/state.rs index 7ddb9a5d..aa41e856 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -467,24 +467,15 @@ pub mod tests { return_funds: bool, token_program_id: ProgramId, amount: u128, - vault_after_return: Option, - receiver_after_return: Option, } #[derive(serde::Serialize, serde::Deserialize)] - #[expect( - clippy::large_enum_variant, - reason = "test-only mirror of guest enum, boxing unnecessary" - )] enum FlashSwapInstruction { Initiate { token_program_id: ProgramId, callback_program_id: ProgramId, amount_out: u128, callback_instruction_data: Vec, - vault_after_transfer: AccountWithMetadata, - receiver_after_transfer: AccountWithMetadata, - vault_after_callback: AccountWithMetadata, }, InvariantCheck { min_vault_balance: u128, @@ -3555,61 +3546,11 @@ pub mod tests { state.force_insert_account(vault_id, vault_account); state.force_insert_account(receiver_id, receiver_account); - // Pre-simulated intermediate states: - // After transfer (vault→receiver, amount_out): - let vault_after_transfer = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: initial_balance - amount_out, - ..Account::default() - }, - false, - vault_id, - ); - let receiver_after_transfer = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: amount_out, - ..Account::default() - }, - false, - receiver_id, - ); - - // After callback returns funds (receiver→vault, amount_out): - let vault_after_callback = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: initial_balance, - ..Account::default() - }, - false, - vault_id, - ); - // Callback instruction: return funds let cb_instruction = CallbackInstruction { return_funds: true, token_program_id: token.id(), amount: amount_out, - vault_after_return: Some(AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: initial_balance, - ..Account::default() - }, - false, - vault_id, - )), - receiver_after_return: Some(AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: 0, - ..Account::default() - }, - false, - receiver_id, - )), }; let cb_data = Program::serialize_instruction(cb_instruction).unwrap(); @@ -3618,9 +3559,6 @@ pub mod tests { callback_program_id: callback.id(), amount_out, callback_instruction_data: cb_data, - vault_after_transfer, - receiver_after_transfer, - vault_after_callback, }; let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction); @@ -3659,44 +3597,11 @@ pub mod tests { state.force_insert_account(vault_id, vault_account); state.force_insert_account(receiver_id, receiver_account); - // Pre-simulated intermediate states (same as successful case for steps 1-2): - let vault_after_transfer = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: initial_balance - amount_out, - ..Account::default() - }, - false, - vault_id, - ); - let receiver_after_transfer = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: amount_out, - ..Account::default() - }, - false, - receiver_id, - ); - - // After callback that does NOT return funds — vault stays drained: - let vault_after_callback = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: initial_balance - amount_out, - ..Account::default() - }, - false, - vault_id, - ); - // Callback instruction: do NOT return funds let cb_instruction = CallbackInstruction { return_funds: false, token_program_id: token.id(), amount: amount_out, - vault_after_return: None, - receiver_after_return: None, }; let cb_data = Program::serialize_instruction(cb_instruction).unwrap(); @@ -3705,9 +3610,6 @@ pub mod tests { callback_program_id: callback.id(), amount_out, callback_instruction_data: cb_data, - vault_after_transfer, - receiver_after_transfer, - vault_after_callback, }; let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction); @@ -3752,58 +3654,10 @@ pub mod tests { state.force_insert_account(vault_id, vault_account); state.force_insert_account(receiver_id, receiver_account); - // Zero-amount transfer: states remain unchanged after transfer - let vault_after_transfer = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: initial_balance, - ..Account::default() - }, - false, - vault_id, - ); - let receiver_after_transfer = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: 0, - ..Account::default() - }, - false, - receiver_id, - ); - // Callback with zero amount, return_funds=true (no-op effectively) - let vault_after_callback = AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: initial_balance, - ..Account::default() - }, - false, - vault_id, - ); - let cb_instruction = CallbackInstruction { return_funds: true, token_program_id: token.id(), amount: 0, - vault_after_return: Some(AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: initial_balance, - ..Account::default() - }, - false, - vault_id, - )), - receiver_after_return: Some(AccountWithMetadata::new( - Account { - program_owner: token.id(), - balance: 0, - ..Account::default() - }, - false, - receiver_id, - )), }; let cb_data = Program::serialize_instruction(cb_instruction).unwrap(); @@ -3812,9 +3666,6 @@ pub mod tests { callback_program_id: callback.id(), amount_out: 0, callback_instruction_data: cb_data, - vault_after_transfer, - receiver_after_transfer, - vault_after_callback, }; let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction); diff --git a/test_program_methods/guest/src/bin/flash_swap_callback.rs b/test_program_methods/guest/src/bin/flash_swap_callback.rs index 24e1c853..e0bdf3ed 100644 --- a/test_program_methods/guest/src/bin/flash_swap_callback.rs +++ b/test_program_methods/guest/src/bin/flash_swap_callback.rs @@ -24,12 +24,9 @@ //! called by any program. In production, a callback would typically verify the caller //! if it needs to trust the context it is called from. -use nssa_core::{ - account::AccountWithMetadata, - program::{ - AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput, - read_nssa_inputs, - }, +use nssa_core::program::{ + AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput, + read_nssa_inputs, }; use serde::{Deserialize, Serialize}; @@ -40,10 +37,6 @@ pub struct CallbackInstruction { pub return_funds: bool, pub token_program_id: ProgramId, pub amount: u128, - /// Pre-simulated vault state after the return transfer (required if `return_funds = true`). - pub vault_after_return: Option, - /// Pre-simulated receiver state after the return transfer (required if `return_funds = true`). - pub receiver_after_return: Option, } fn main() { diff --git a/test_program_methods/guest/src/bin/flash_swap_initiator.rs b/test_program_methods/guest/src/bin/flash_swap_initiator.rs index 8f1c28fb..7c178b28 100644 --- a/test_program_methods/guest/src/bin/flash_swap_initiator.rs +++ b/test_program_methods/guest/src/bin/flash_swap_initiator.rs @@ -24,8 +24,8 @@ //! //! - `self_program_id`: enables a program to chain back to itself (step 3 above) //! - `caller_program_id`: enables a program to restrict which callers can invoke an instruction -//! - Pre-simulated intermediate states: the initiator must compute expected intermediate account -//! states and embed them in the instruction. The node validates them deterministically. +//! - Computed intermediate states: the initiator computes expected intermediate account +//! states from the pre_states and amount, keeping the instruction minimal. //! - Atomic rollback: if the callback doesn't return funds, the invariant check fails, and all //! state changes from steps 1 and 2 are rolled back automatically. //! @@ -37,12 +37,9 @@ //! - `flash_swap_self_call_targets_correct_program`: zero-amount self-call isolation test //! - `flash_swap_standalone_invariant_check_rejected`: `caller_program_id` access control -use nssa_core::{ - account::AccountWithMetadata, - program::{ - AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput, - read_nssa_inputs, - }, +use nssa_core::program::{ + AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput, + read_nssa_inputs, }; use serde::{Deserialize, Serialize}; @@ -59,20 +56,12 @@ pub enum FlashSwapInstruction { /// 2. Callback (user logic, e.g. arbitrage) /// 3. Self-call `InvariantCheck` (verify vault balance did not decrease) /// - /// The caller must pre-simulate the entire call graph and provide the expected - /// intermediate account states. The node validates them deterministically at each step. + /// Intermediate account states are computed inside the program from pre_states and amount_out. Initiate { token_program_id: ProgramId, callback_program_id: ProgramId, amount_out: u128, callback_instruction_data: Vec, - /// Expected vault state after the token transfer (vault balance -= `amount_out`). - vault_after_transfer: AccountWithMetadata, - /// Expected receiver state after the token transfer (receiver balance += `amount_out`). - receiver_after_transfer: AccountWithMetadata, - /// Expected vault state after the callback completes (should match initial balance - /// if the callback correctly returns funds). - vault_after_callback: AccountWithMetadata, }, /// Internal: verify the vault invariant holds after callback execution. /// @@ -100,9 +89,6 @@ fn main() { callback_program_id, amount_out, callback_instruction_data, - vault_after_transfer, - receiver_after_transfer, - vault_after_callback, } => { let Ok([vault_pre, receiver_pre]) = <[_; 2]>::try_from(pre_states) else { panic!("Initiate requires exactly 2 accounts: vault, receiver"); @@ -111,6 +97,28 @@ fn main() { // Capture initial vault balance, the invariant check will verify it is restored. let min_vault_balance = vault_pre.account.balance; + // Compute intermediate account states from pre_states and amount_out. + let mut vault_after_transfer = vault_pre.clone(); + vault_after_transfer.account.balance = vault_pre + .account + .balance + .checked_sub(amount_out) + .expect("vault has insufficient balance for flash swap"); + + let mut receiver_after_transfer = receiver_pre.clone(); + receiver_after_transfer.account.balance = receiver_pre + .account + .balance + .checked_add(amount_out) + .expect("receiver balance overflow"); + + let mut vault_after_callback = vault_after_transfer.clone(); + vault_after_callback.account.balance = vault_after_transfer + .account + .balance + .checked_add(amount_out) + .expect("vault balance overflow after callback"); + // Chained call 1: Token transfer (vault → receiver). // The vault is a PDA of this initiator program (seed = [0_u8; 32]), so we provide // the PDA seed to authorize the token program to debit the vault on our behalf. @@ -175,10 +183,11 @@ fn main() { // When called as a top-level transaction, `caller_program_id` is `None` → panics. // When called as a chained call from `Initiate`, `caller_program_id` is // `Some(self_program_id)` → passes. - assert!( - caller_program_id == Some(self_program_id), + assert_eq!( + caller_program_id, + Some(self_program_id), "InvariantCheck is an internal instruction: must be called by flash_swap_initiator \ - via a chained call, got caller_program_id: {caller_program_id:?}", + via a chained call", ); let Ok([vault]) = <[_; 1]>::try_from(pre_states) else {