diff --git a/amm/amm-idl.json b/amm/amm-idl.json index e3b1747..a1ef5f9 100644 --- a/amm/amm-idl.json +++ b/amm/amm-idl.json @@ -240,6 +240,55 @@ } ] }, + { + "name": "swap_exact_output", + "accounts": [ + { + "name": "pool", + "writable": false, + "signer": false, + "init": false + }, + { + "name": "vault_a", + "writable": false, + "signer": false, + "init": false + }, + { + "name": "vault_b", + "writable": false, + "signer": false, + "init": false + }, + { + "name": "user_holding_a", + "writable": false, + "signer": false, + "init": false + }, + { + "name": "user_holding_b", + "writable": false, + "signer": false, + "init": false + } + ], + "args": [ + { + "name": "exact_amount_out", + "type": "u128" + }, + { + "name": "max_amount_in", + "type": "u128" + }, + { + "name": "token_definition_id_in", + "type": "account_id" + } + ] + }, { "name": "sync_reserves", "accounts": [ diff --git a/amm/core/src/lib.rs b/amm/core/src/lib.rs index 84977ea..a883ed1 100644 --- a/amm/core/src/lib.rs +++ b/amm/core/src/lib.rs @@ -84,6 +84,22 @@ pub enum Instruction { token_definition_id_in: AccountId, }, + /// Swap tokens specifying the exact desired output amount, + /// while maintaining the Pool constant product. + /// + /// Required accounts: + /// - AMM Pool (initialized) + /// - Vault Holding Account for Token A (initialized) + /// - Vault Holding Account for Token B (initialized) + /// - User Holding Account for Token A + /// - User Holding Account for Token B Either User Holding Account for Token A or Token B is + /// authorized. + SwapExactOutput { + exact_amount_out: u128, + max_amount_in: u128, + token_definition_id_in: AccountId, + }, + /// Sync pool reserves with current vault balances. /// /// Required accounts: diff --git a/amm/methods/guest/src/bin/amm.rs b/amm/methods/guest/src/bin/amm.rs index 5ea0cbb..dd77dab 100644 --- a/amm/methods/guest/src/bin/amm.rs +++ b/amm/methods/guest/src/bin/amm.rs @@ -130,6 +130,31 @@ mod amm { Ok(SpelOutput::with_chained_calls(post_states, chained_calls)) } + /// Swap tokens specifying the exact desired output amount. + #[instruction] + pub fn swap_exact_output( + pool: AccountWithMetadata, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + exact_amount_out: u128, + max_amount_in: u128, + token_definition_id_in: AccountId, + ) -> SpelResult { + let (post_states, chained_calls) = amm_program::swap::swap_exact_output( + pool, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + exact_amount_out, + max_amount_in, + token_definition_id_in, + ); + Ok(SpelOutput::with_chained_calls(post_states, chained_calls)) + } + /// Sync pool reserves with current vault balances. #[instruction] pub fn sync_reserves( diff --git a/amm/src/swap.rs b/amm/src/swap.rs index 759c83d..3aae808 100644 --- a/amm/src/swap.rs +++ b/amm/src/swap.rs @@ -4,20 +4,14 @@ use nssa_core::{ program::{AccountPostState, ChainedCall}, }; -#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] -pub fn swap( - pool: AccountWithMetadata, - vault_a: AccountWithMetadata, - vault_b: AccountWithMetadata, - user_holding_a: AccountWithMetadata, - user_holding_b: AccountWithMetadata, - swap_amount_in: u128, - min_amount_out: u128, - token_in_id: AccountId, -) -> (Vec, Vec) { - // Verify vaults are in fact vaults +/// Validates swap setup: checks pool is active, vaults match, and reserves are sufficient. +fn validate_swap_setup( + pool: &AccountWithMetadata, + vault_a: &AccountWithMetadata, + vault_b: &AccountWithMetadata, +) -> PoolDefinition { let pool_def_data = PoolDefinition::try_from(&pool.account.data) - .expect("Swap: AMM Program expects a valid Pool Definition Account"); + .expect("AMM Program expects a valid Pool Definition Account"); assert!(pool_def_data.active, "Pool is inactive"); assert_eq!( @@ -29,16 +23,14 @@ pub fn swap( "Vault B was not provided" ); - // fetch pool reserves - // validates reserves is at least the vaults' balances let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data) - .expect("Swap: AMM Program expects a valid Token Holding Account for Vault A"); + .expect("AMM Program expects a valid Token Holding Account for Vault A"); let token_core::TokenHolding::Fungible { definition_id: _, balance: vault_a_balance, } = vault_a_token_holding else { - panic!("Swap: AMM Program expects a valid Fungible Token Holding Account for Vault A"); + panic!("AMM Program expects a valid Fungible Token Holding Account for Vault A"); }; assert!( @@ -47,13 +39,13 @@ pub fn swap( ); let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data) - .expect("Swap: AMM Program expects a valid Token Holding Account for Vault B"); + .expect("AMM Program expects a valid Token Holding Account for Vault B"); let token_core::TokenHolding::Fungible { definition_id: _, balance: vault_b_balance, } = vault_b_token_holding else { - panic!("Swap: AMM Program expects a valid Fungible Token Holding Account for Vault B"); + panic!("AMM Program expects a valid Fungible Token Holding Account for Vault B"); }; assert!( @@ -61,6 +53,59 @@ pub fn swap( "Reserve for Token B exceeds vault balance" ); + pool_def_data +} + +/// Creates post-state and returns reserves after swap. +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +#[expect( + clippy::needless_pass_by_value, + reason = "consistent with codebase style" +)] +fn create_swap_post_states( + pool: AccountWithMetadata, + pool_def_data: PoolDefinition, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + deposit_a: u128, + withdraw_a: u128, + deposit_b: u128, + withdraw_b: u128, +) -> Vec { + let mut pool_post = pool.account; + let pool_post_definition = PoolDefinition { + reserve_a: pool_def_data.reserve_a + deposit_a - withdraw_a, + reserve_b: pool_def_data.reserve_b + deposit_b - withdraw_b, + ..pool_def_data + }; + + pool_post.data = Data::from(&pool_post_definition); + + vec![ + AccountPostState::new(pool_post), + AccountPostState::new(vault_a.account), + AccountPostState::new(vault_b.account), + AccountPostState::new(user_holding_a.account), + AccountPostState::new(user_holding_b.account), + ] +} + +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +#[must_use] +pub fn swap( + pool: AccountWithMetadata, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + swap_amount_in: u128, + min_amount_out: u128, + token_in_id: AccountId, +) -> (Vec, Vec) { + let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b); + let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) = if token_in_id == pool_def_data.definition_token_a_id { let (chained_calls, deposit_a, withdraw_b) = swap_logic( @@ -94,23 +139,18 @@ pub fn swap( panic!("AccountId is not a token type for the pool"); }; - // Update pool account - let mut pool_post = pool.account.clone(); - let pool_post_definition = PoolDefinition { - reserve_a: pool_def_data.reserve_a + deposit_a - withdraw_a, - reserve_b: pool_def_data.reserve_b + deposit_b - withdraw_b, - ..pool_def_data - }; - - pool_post.data = Data::from(&pool_post_definition); - - let post_states = vec![ - AccountPostState::new(pool_post.clone()), - AccountPostState::new(vault_a.account.clone()), - AccountPostState::new(vault_b.account.clone()), - AccountPostState::new(user_holding_a.account.clone()), - AccountPostState::new(user_holding_b.account.clone()), - ]; + let post_states = create_swap_post_states( + pool, + pool_def_data, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + deposit_a, + withdraw_a, + deposit_b, + withdraw_b, + ); (post_states, chained_calls) } @@ -130,7 +170,9 @@ fn swap_logic( // Compute withdraw amount // Maintains pool constant product // k = pool_def_data.reserve_a * pool_def_data.reserve_b; - let withdraw_amount = (reserve_withdraw_vault_amount * swap_amount_in) + let withdraw_amount = reserve_withdraw_vault_amount + .checked_mul(swap_amount_in) + .expect("reserve * amount_in overflows u128") / (reserve_deposit_vault_amount + swap_amount_in); // Slippage check @@ -174,3 +216,135 @@ fn swap_logic( (chained_calls, swap_amount_in, withdraw_amount) } + +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +#[must_use] +pub fn swap_exact_output( + pool: AccountWithMetadata, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + exact_amount_out: u128, + max_amount_in: u128, + token_in_id: AccountId, +) -> (Vec, Vec) { + let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b); + + let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) = + if token_in_id == pool_def_data.definition_token_a_id { + let (chained_calls, deposit_a, withdraw_b) = exact_output_swap_logic( + user_holding_a.clone(), + vault_a.clone(), + vault_b.clone(), + user_holding_b.clone(), + exact_amount_out, + max_amount_in, + pool_def_data.reserve_a, + pool_def_data.reserve_b, + pool.account_id, + ); + + (chained_calls, [deposit_a, 0], [0, withdraw_b]) + } else if token_in_id == pool_def_data.definition_token_b_id { + let (chained_calls, deposit_b, withdraw_a) = exact_output_swap_logic( + user_holding_b.clone(), + vault_b.clone(), + vault_a.clone(), + user_holding_a.clone(), + exact_amount_out, + max_amount_in, + pool_def_data.reserve_b, + pool_def_data.reserve_a, + pool.account_id, + ); + + (chained_calls, [0, withdraw_a], [deposit_b, 0]) + } else { + panic!("AccountId is not a token type for the pool"); + }; + + let post_states = create_swap_post_states( + pool, + pool_def_data, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + deposit_a, + withdraw_a, + deposit_b, + withdraw_b, + ); + + (post_states, chained_calls) +} + +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +fn exact_output_swap_logic( + user_deposit: AccountWithMetadata, + vault_deposit: AccountWithMetadata, + vault_withdraw: AccountWithMetadata, + user_withdraw: AccountWithMetadata, + exact_amount_out: u128, + max_amount_in: u128, + reserve_deposit_vault_amount: u128, + reserve_withdraw_vault_amount: u128, + pool_id: AccountId, +) -> (Vec, u128, u128) { + // Guard: exact_amount_out must be nonzero + assert_ne!(exact_amount_out, 0, "Exact amount out must be nonzero"); + + // Guard: exact_amount_out must be less than reserve_withdraw_vault_amount + assert!( + exact_amount_out < reserve_withdraw_vault_amount, + "Exact amount out exceeds reserve" + ); + + // Compute deposit amount using ceiling division + // Formula: amount_in = ceil(reserve_in * exact_amount_out / (reserve_out - exact_amount_out)) + let deposit_amount = reserve_deposit_vault_amount + .checked_mul(exact_amount_out) + .expect("reserve * amount_out overflows u128") + .div_ceil(reserve_withdraw_vault_amount - exact_amount_out); + + // Slippage check + assert!( + deposit_amount <= max_amount_in, + "Required input exceeds maximum amount in" + ); + + let token_program_id = user_deposit.account.program_owner; + + let mut chained_calls = Vec::new(); + chained_calls.push(ChainedCall::new( + token_program_id, + vec![user_deposit, vault_deposit], + &token_core::Instruction::Transfer { + amount_to_transfer: deposit_amount, + }, + )); + + let mut vault_withdraw = vault_withdraw; + vault_withdraw.is_authorized = true; + + let pda_seed = compute_vault_pda_seed( + pool_id, + token_core::TokenHolding::try_from(&vault_withdraw.account.data) + .expect("Exact Output Swap Logic: AMM Program expects valid token data") + .definition_id(), + ); + + chained_calls.push( + ChainedCall::new( + token_program_id, + vec![vault_withdraw, user_withdraw], + &token_core::Instruction::Transfer { + amount_to_transfer: exact_amount_out, + }, + ) + .with_pda_seeds(vec![pda_seed]), + ); + + (chained_calls, deposit_amount, exact_amount_out) +} diff --git a/amm/src/tests.rs b/amm/src/tests.rs index 341dc4f..489f8a4 100644 --- a/amm/src/tests.rs +++ b/amm/src/tests.rs @@ -13,7 +13,10 @@ use nssa_core::{ use token_core::{TokenDefinition, TokenHolding}; use crate::{ - add::add_liquidity, new_definition::new_definition, remove::remove_liquidity, swap::swap, + add::add_liquidity, + new_definition::new_definition, + remove::remove_liquidity, + swap::{swap, swap_exact_output}, sync::sync_reserves, }; @@ -157,6 +160,10 @@ impl BalanceForTests { BalanceForTests::add_max_amount_b() } + fn max_amount_in() -> u128 { + 166 + } + fn vault_a_remove_successful() -> u128 { BalanceForTests::vault_a_reserve_init() - BalanceForTests::remove_actual_a_successful() } @@ -263,6 +270,74 @@ impl ChainedCallForTests { ) } + fn cc_swap_exact_output_token_a_test_1() -> ChainedCall { + let swap_amount: u128 = 498; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::vault_a_init(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + } + + fn cc_swap_exact_output_token_b_test_1() -> ChainedCall { + let swap_amount: u128 = 166; + + let mut vault_b_auth = AccountWithMetadataForTests::vault_b_init(); + vault_b_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_b_auth, AccountWithMetadataForTests::user_holding_b()], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_b_definition_id(), + )]) + } + + fn cc_swap_exact_output_token_a_test_2() -> ChainedCall { + let swap_amount: u128 = 285; + + let mut vault_a_auth = AccountWithMetadataForTests::vault_a_init(); + vault_a_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_a_auth, AccountWithMetadataForTests::user_holding_a()], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_a_definition_id(), + )]) + } + + fn cc_swap_exact_output_token_b_test_2() -> ChainedCall { + let swap_amount: u128 = 200; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::vault_b_init(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + } + fn cc_add_token_a() -> ChainedCall { ChainedCall::new( TOKEN_PROGRAM_ID, @@ -804,6 +879,34 @@ impl AccountWithMetadataForTests { } } + /// A smaller pool (reserve_a=1_000, reserve_b=500) used exclusively by + /// swap-exact-output tests, whose hardcoded expected values were computed + /// against these reserves. vault_a_init/vault_b_init still satisfy the + /// balance ≥ reserve check (5_000 ≥ 1_000, 2_500 ≥ 500). + fn pool_definition_swap_exact_output_init() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::lp_supply_init(), + reserve_a: 1_000, + reserve_b: 500, + fees: 0u128, + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + fn pool_definition_init_reserve_a_zero() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -948,6 +1051,54 @@ impl AccountWithMetadataForTests { } } + fn pool_definition_swap_exact_output_test_1() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0_u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::lp_supply_init(), + reserve_a: 1498_u128, + reserve_b: 334_u128, + fees: 0_u128, + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + + fn pool_definition_swap_exact_output_test_2() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0_u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::lp_supply_init(), + reserve_a: 715_u128, + reserve_b: 700_u128, + fees: 0_u128, + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + fn pool_definition_add_zero_lp() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -2017,6 +2168,281 @@ fn test_call_swap_chained_call_successful_2() { ); } +#[should_panic(expected = "AccountId is not a token type for the pool")] +#[test] +fn call_swap_exact_output_incorrect_token_type() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_lp_definition_id(), + ); +} + +#[should_panic(expected = "Vault A was not provided")] +#[test] +fn call_swap_exact_output_vault_a_omitted() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_with_wrong_id(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Vault B was not provided")] +#[test] +fn call_swap_exact_output_vault_b_omitted() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_with_wrong_id(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Reserve for Token A exceeds vault balance")] +#[test] +fn call_swap_exact_output_reserves_vault_mismatch_1() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init_low(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Reserve for Token B exceeds vault balance")] +#[test] +fn call_swap_exact_output_reserves_vault_mismatch_2() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init_low(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Pool is inactive")] +#[test] +fn call_swap_exact_output_inactive() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_inactive(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Required input exceeds maximum amount in")] +#[test] +fn call_swap_exact_output_exceeds_max_in() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 166_u128, + 100_u128, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Exact amount out must be nonzero")] +#[test] +fn call_swap_exact_output_zero() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 0_u128, + 500_u128, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Exact amount out exceeds reserve")] +#[test] +fn call_swap_exact_output_exceeds_reserve() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::vault_b_reserve_init(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[test] +fn call_swap_exact_output_chained_call_successful() { + let (post_states, chained_calls) = swap_exact_output( + AccountWithMetadataForTests::pool_definition_swap_exact_output_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::max_amount_in(), + BalanceForTests::vault_b_reserve_init(), + IdForTests::token_a_definition_id(), + ); + + let pool_post = post_states[0].clone(); + + assert!( + AccountWithMetadataForTests::pool_definition_swap_exact_output_test_1().account + == *pool_post.account() + ); + + let chained_call_a = chained_calls[0].clone(); + let chained_call_b = chained_calls[1].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_swap_exact_output_token_a_test_1() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_swap_exact_output_token_b_test_1() + ); +} + +#[test] +fn call_swap_exact_output_chained_call_successful_2() { + let (post_states, chained_calls) = swap_exact_output( + AccountWithMetadataForTests::pool_definition_swap_exact_output_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 285, + 300, + IdForTests::token_b_definition_id(), + ); + + let pool_post = post_states[0].clone(); + + assert!( + AccountWithMetadataForTests::pool_definition_swap_exact_output_test_2().account + == *pool_post.account() + ); + + let chained_call_a = chained_calls[1].clone(); + let chained_call_b = chained_calls[0].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_swap_exact_output_token_a_test_2() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_swap_exact_output_token_b_test_2() + ); +} + +// Without the fix, `reserve_a * exact_amount_out` silently wraps to 0 in release mode, +// making `deposit_amount = 0`. The slippage check `0 <= max_amount_in` always passes, +// so an attacker receives `exact_amount_out` tokens while paying nothing. +#[should_panic(expected = "reserve * amount_out overflows u128")] +#[test] +fn swap_exact_output_overflow_protection() { + // reserve_a chosen so that reserve_a * 2 overflows u128: + // (u128::MAX / 2 + 1) * 2 = u128::MAX + 1 → wraps to 0 + let large_reserve: u128 = u128::MAX / 2 + 1; + let reserve_b: u128 = 1_000; + + let pool = AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: 1, + reserve_a: large_reserve, + reserve_b, + fees: 0, + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + }; + + let vault_a = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance: large_reserve, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::vault_a_id(), + }; + + let vault_b = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance: reserve_b, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::vault_b_id(), + }; + + let _result = swap_exact_output( + pool, + vault_a, + vault_b, + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 2, // exact_amount_out: small, valid (< reserve_b) + 1, // max_amount_in: tiny — real deposit would be enormous, but + // overflow wraps it to 0, making 0 <= 1 pass silently + IdForTests::token_a_definition_id(), + ); +} + #[test] fn test_new_definition_lp_asymmetric_amounts() { let (post_states, chained_calls) = new_definition(