fix: compute intermediate states inside flash swap programs

This commit is contained in:
Moudy 2026-04-03 22:05:49 +02:00 committed by moudyellaz
parent c85f19fe85
commit 3cfc74695b
3 changed files with 35 additions and 182 deletions

View File

@ -467,24 +467,15 @@ pub mod tests {
return_funds: bool,
token_program_id: ProgramId,
amount: u128,
vault_after_return: Option<AccountWithMetadata>,
receiver_after_return: Option<AccountWithMetadata>,
}
#[derive(serde::Serialize, serde::Deserialize)]
#[expect(
clippy::large_enum_variant,
reason = "test-only mirror of guest enum, boxing unnecessary"
)]
enum FlashSwapInstruction {
Initiate {
token_program_id: ProgramId,
callback_program_id: ProgramId,
amount_out: u128,
callback_instruction_data: Vec<u32>,
vault_after_transfer: AccountWithMetadata,
receiver_after_transfer: AccountWithMetadata,
vault_after_callback: AccountWithMetadata,
},
InvariantCheck {
min_vault_balance: u128,
@ -3555,61 +3546,11 @@ pub mod tests {
state.force_insert_account(vault_id, vault_account);
state.force_insert_account(receiver_id, receiver_account);
// Pre-simulated intermediate states:
// After transfer (vault→receiver, amount_out):
let vault_after_transfer = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: initial_balance - amount_out,
..Account::default()
},
false,
vault_id,
);
let receiver_after_transfer = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: amount_out,
..Account::default()
},
false,
receiver_id,
);
// After callback returns funds (receiver→vault, amount_out):
let vault_after_callback = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: initial_balance,
..Account::default()
},
false,
vault_id,
);
// Callback instruction: return funds
let cb_instruction = CallbackInstruction {
return_funds: true,
token_program_id: token.id(),
amount: amount_out,
vault_after_return: Some(AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: initial_balance,
..Account::default()
},
false,
vault_id,
)),
receiver_after_return: Some(AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: 0,
..Account::default()
},
false,
receiver_id,
)),
};
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
@ -3618,9 +3559,6 @@ pub mod tests {
callback_program_id: callback.id(),
amount_out,
callback_instruction_data: cb_data,
vault_after_transfer,
receiver_after_transfer,
vault_after_callback,
};
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);
@ -3659,44 +3597,11 @@ pub mod tests {
state.force_insert_account(vault_id, vault_account);
state.force_insert_account(receiver_id, receiver_account);
// Pre-simulated intermediate states (same as successful case for steps 1-2):
let vault_after_transfer = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: initial_balance - amount_out,
..Account::default()
},
false,
vault_id,
);
let receiver_after_transfer = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: amount_out,
..Account::default()
},
false,
receiver_id,
);
// After callback that does NOT return funds — vault stays drained:
let vault_after_callback = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: initial_balance - amount_out,
..Account::default()
},
false,
vault_id,
);
// Callback instruction: do NOT return funds
let cb_instruction = CallbackInstruction {
return_funds: false,
token_program_id: token.id(),
amount: amount_out,
vault_after_return: None,
receiver_after_return: None,
};
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
@ -3705,9 +3610,6 @@ pub mod tests {
callback_program_id: callback.id(),
amount_out,
callback_instruction_data: cb_data,
vault_after_transfer,
receiver_after_transfer,
vault_after_callback,
};
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);
@ -3752,58 +3654,10 @@ pub mod tests {
state.force_insert_account(vault_id, vault_account);
state.force_insert_account(receiver_id, receiver_account);
// Zero-amount transfer: states remain unchanged after transfer
let vault_after_transfer = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: initial_balance,
..Account::default()
},
false,
vault_id,
);
let receiver_after_transfer = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: 0,
..Account::default()
},
false,
receiver_id,
);
// Callback with zero amount, return_funds=true (no-op effectively)
let vault_after_callback = AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: initial_balance,
..Account::default()
},
false,
vault_id,
);
let cb_instruction = CallbackInstruction {
return_funds: true,
token_program_id: token.id(),
amount: 0,
vault_after_return: Some(AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: initial_balance,
..Account::default()
},
false,
vault_id,
)),
receiver_after_return: Some(AccountWithMetadata::new(
Account {
program_owner: token.id(),
balance: 0,
..Account::default()
},
false,
receiver_id,
)),
};
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
@ -3812,9 +3666,6 @@ pub mod tests {
callback_program_id: callback.id(),
amount_out: 0,
callback_instruction_data: cb_data,
vault_after_transfer,
receiver_after_transfer,
vault_after_callback,
};
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);

View File

@ -24,12 +24,9 @@
//! called by any program. In production, a callback would typically verify the caller
//! if it needs to trust the context it is called from.
use nssa_core::{
account::AccountWithMetadata,
program::{
AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput,
read_nssa_inputs,
},
use nssa_core::program::{
AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput,
read_nssa_inputs,
};
use serde::{Deserialize, Serialize};
@ -40,10 +37,6 @@ pub struct CallbackInstruction {
pub return_funds: bool,
pub token_program_id: ProgramId,
pub amount: u128,
/// Pre-simulated vault state after the return transfer (required if `return_funds = true`).
pub vault_after_return: Option<AccountWithMetadata>,
/// Pre-simulated receiver state after the return transfer (required if `return_funds = true`).
pub receiver_after_return: Option<AccountWithMetadata>,
}
fn main() {

View File

@ -24,8 +24,8 @@
//!
//! - `self_program_id`: enables a program to chain back to itself (step 3 above)
//! - `caller_program_id`: enables a program to restrict which callers can invoke an instruction
//! - Pre-simulated intermediate states: the initiator must compute expected intermediate account
//! states and embed them in the instruction. The node validates them deterministically.
//! - Computed intermediate states: the initiator computes expected intermediate account
//! states from the pre_states and amount, keeping the instruction minimal.
//! - Atomic rollback: if the callback doesn't return funds, the invariant check fails, and all
//! state changes from steps 1 and 2 are rolled back automatically.
//!
@ -37,12 +37,9 @@
//! - `flash_swap_self_call_targets_correct_program`: zero-amount self-call isolation test
//! - `flash_swap_standalone_invariant_check_rejected`: `caller_program_id` access control
use nssa_core::{
account::AccountWithMetadata,
program::{
AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput,
read_nssa_inputs,
},
use nssa_core::program::{
AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput,
read_nssa_inputs,
};
use serde::{Deserialize, Serialize};
@ -59,20 +56,12 @@ pub enum FlashSwapInstruction {
/// 2. Callback (user logic, e.g. arbitrage)
/// 3. Self-call `InvariantCheck` (verify vault balance did not decrease)
///
/// The caller must pre-simulate the entire call graph and provide the expected
/// intermediate account states. The node validates them deterministically at each step.
/// Intermediate account states are computed inside the program from pre_states and amount_out.
Initiate {
token_program_id: ProgramId,
callback_program_id: ProgramId,
amount_out: u128,
callback_instruction_data: Vec<u32>,
/// Expected vault state after the token transfer (vault balance -= `amount_out`).
vault_after_transfer: AccountWithMetadata,
/// Expected receiver state after the token transfer (receiver balance += `amount_out`).
receiver_after_transfer: AccountWithMetadata,
/// Expected vault state after the callback completes (should match initial balance
/// if the callback correctly returns funds).
vault_after_callback: AccountWithMetadata,
},
/// Internal: verify the vault invariant holds after callback execution.
///
@ -100,9 +89,6 @@ fn main() {
callback_program_id,
amount_out,
callback_instruction_data,
vault_after_transfer,
receiver_after_transfer,
vault_after_callback,
} => {
let Ok([vault_pre, receiver_pre]) = <[_; 2]>::try_from(pre_states) else {
panic!("Initiate requires exactly 2 accounts: vault, receiver");
@ -111,6 +97,28 @@ fn main() {
// Capture initial vault balance, the invariant check will verify it is restored.
let min_vault_balance = vault_pre.account.balance;
// Compute intermediate account states from pre_states and amount_out.
let mut vault_after_transfer = vault_pre.clone();
vault_after_transfer.account.balance = vault_pre
.account
.balance
.checked_sub(amount_out)
.expect("vault has insufficient balance for flash swap");
let mut receiver_after_transfer = receiver_pre.clone();
receiver_after_transfer.account.balance = receiver_pre
.account
.balance
.checked_add(amount_out)
.expect("receiver balance overflow");
let mut vault_after_callback = vault_after_transfer.clone();
vault_after_callback.account.balance = vault_after_transfer
.account
.balance
.checked_add(amount_out)
.expect("vault balance overflow after callback");
// Chained call 1: Token transfer (vault → receiver).
// The vault is a PDA of this initiator program (seed = [0_u8; 32]), so we provide
// the PDA seed to authorize the token program to debit the vault on our behalf.
@ -175,10 +183,11 @@ fn main() {
// When called as a top-level transaction, `caller_program_id` is `None` → panics.
// When called as a chained call from `Initiate`, `caller_program_id` is
// `Some(self_program_id)` → passes.
assert!(
caller_program_id == Some(self_program_id),
assert_eq!(
caller_program_id,
Some(self_program_id),
"InvariantCheck is an internal instruction: must be called by flash_swap_initiator \
via a chained call, got caller_program_id: {caller_program_id:?}",
via a chained call",
);
let Ok([vault]) = <[_; 1]>::try_from(pre_states) else {