From 9a6ec0018b6e471ac027ad2c6180b29b70ddc37a Mon Sep 17 00:00:00 2001 From: Andrea Franz Date: Wed, 25 Feb 2026 13:11:54 +0100 Subject: [PATCH] feat(programs/amm): add swap exact output functionality --- program_methods/guest/src/bin/amm.rs | 21 +- programs/amm/core/src/lib.rs | 16 ++ programs/amm/src/swap.rs | 299 +++++++++++++++----- programs/amm/src/tests.rs | 400 ++++++++++++++++++++++++++- wallet/src/cli/programs/amm.rs | 55 ++++ wallet/src/program_facades/amm.rs | 97 ++++++- 6 files changed, 810 insertions(+), 78 deletions(-) diff --git a/program_methods/guest/src/bin/amm.rs b/program_methods/guest/src/bin/amm.rs index 748630d9..1c06389a 100644 --- a/program_methods/guest/src/bin/amm.rs +++ b/program_methods/guest/src/bin/amm.rs @@ -119,7 +119,7 @@ fn main() { } => { let [pool, vault_a, vault_b, user_holding_a, user_holding_b] = pre_states .try_into() - .expect("Transfer instruction requires exactly five accounts"); + .expect("Swap instruction requires exactly five accounts"); amm_program::swap::swap( pool, vault_a, @@ -131,6 +131,25 @@ fn main() { token_definition_id_in, ) } + Instruction::SwapExactOutput { + exact_amount_out, + max_amount_in, + token_definition_id_in, + } => { + let [pool, vault_a, vault_b, user_holding_a, user_holding_b] = pre_states + .try_into() + .expect("SwapExactOutput instruction requires exactly five accounts"); + amm_program::swap::swap_exact_output( + pool, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + exact_amount_out, + max_amount_in, + token_definition_id_in, + ) + } }; ProgramOutput::new(instruction_words, pre_states_clone, post_states) diff --git a/programs/amm/core/src/lib.rs b/programs/amm/core/src/lib.rs index 85efd00d..5a9dda8d 100644 --- a/programs/amm/core/src/lib.rs +++ b/programs/amm/core/src/lib.rs @@ -73,6 +73,22 @@ pub enum Instruction { min_amount_out: u128, token_definition_id_in: AccountId, }, + + /// Swap tokens specifying the exact desired output amount, + /// while maintaining the Pool constant product. + /// + /// Required accounts: + /// - AMM Pool (initialized) + /// - Vault Holding Account for Token A (initialized) + /// - Vault Holding Account for Token B (initialized) + /// - User Holding Account for Token A + /// - User Holding Account for Token B Either User Holding Account for Token A or Token B is + /// authorized. + SwapExactOutput { + exact_amount_out: u128, + max_amount_in: u128, + token_definition_id_in: AccountId, + }, } #[derive(Clone, Default, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] diff --git a/programs/amm/src/swap.rs b/programs/amm/src/swap.rs index cb64f5eb..3cc84d35 100644 --- a/programs/amm/src/swap.rs +++ b/programs/amm/src/swap.rs @@ -4,6 +4,94 @@ use nssa_core::{ program::{AccountPostState, ChainedCall}, }; +/// Validates swap setup: checks pool is active, vaults match, and reserves are sufficient. +fn validate_swap_setup( + pool: &AccountWithMetadata, + vault_a: &AccountWithMetadata, + vault_b: &AccountWithMetadata, +) -> PoolDefinition { + let pool_def_data = PoolDefinition::try_from(&pool.account.data) + .expect("AMM Program expects a valid Pool Definition Account"); + + assert!(pool_def_data.active, "Pool is inactive"); + assert_eq!( + vault_a.account_id, pool_def_data.vault_a_id, + "Vault A was not provided" + ); + assert_eq!( + vault_b.account_id, pool_def_data.vault_b_id, + "Vault B was not provided" + ); + + let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data) + .expect("AMM Program expects a valid Token Holding Account for Vault A"); + let token_core::TokenHolding::Fungible { + definition_id: _, + balance: vault_a_balance, + } = vault_a_token_holding + else { + panic!("AMM Program expects a valid Fungible Token Holding Account for Vault A"); + }; + + assert!( + vault_a_balance >= pool_def_data.reserve_a, + "Reserve for Token A exceeds vault balance" + ); + + let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data) + .expect("AMM Program expects a valid Token Holding Account for Vault B"); + let token_core::TokenHolding::Fungible { + definition_id: _, + balance: vault_b_balance, + } = vault_b_token_holding + else { + panic!("AMM Program expects a valid Fungible Token Holding Account for Vault B"); + }; + + assert!( + vault_b_balance >= pool_def_data.reserve_b, + "Reserve for Token B exceeds vault balance" + ); + + pool_def_data +} + +/// Creates post-state and returns reserves after swap. +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +#[expect( + clippy::needless_pass_by_value, + reason = "consistent with codebase style" +)] +fn create_swap_post_states( + pool: AccountWithMetadata, + pool_def_data: PoolDefinition, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + deposit_a: u128, + withdraw_a: u128, + deposit_b: u128, + withdraw_b: u128, +) -> Vec { + let mut pool_post = pool.account; + let pool_post_definition = PoolDefinition { + reserve_a: pool_def_data.reserve_a + deposit_a - withdraw_a, + reserve_b: pool_def_data.reserve_b + deposit_b - withdraw_b, + ..pool_def_data + }; + + pool_post.data = Data::from(&pool_post_definition); + + vec![ + AccountPostState::new(pool_post), + AccountPostState::new(vault_a.account), + AccountPostState::new(vault_b.account), + AccountPostState::new(user_holding_a.account), + AccountPostState::new(user_holding_b.account), + ] +} + #[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] #[must_use] pub fn swap( @@ -16,51 +104,7 @@ pub fn swap( min_amount_out: u128, token_in_id: AccountId, ) -> (Vec, Vec) { - // Verify vaults are in fact vaults - let pool_def_data = PoolDefinition::try_from(&pool.account.data) - .expect("Swap: AMM Program expects a valid Pool Definition Account"); - - assert!(pool_def_data.active, "Pool is inactive"); - assert_eq!( - vault_a.account_id, pool_def_data.vault_a_id, - "Vault A was not provided" - ); - assert_eq!( - vault_b.account_id, pool_def_data.vault_b_id, - "Vault B was not provided" - ); - - // fetch pool reserves - // validates reserves is at least the vaults' balances - let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data) - .expect("Swap: AMM Program expects a valid Token Holding Account for Vault A"); - let token_core::TokenHolding::Fungible { - definition_id: _, - balance: vault_a_balance, - } = vault_a_token_holding - else { - panic!("Swap: AMM Program expects a valid Fungible Token Holding Account for Vault A"); - }; - - assert!( - vault_a_balance >= pool_def_data.reserve_a, - "Reserve for Token A exceeds vault balance" - ); - - let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data) - .expect("Swap: AMM Program expects a valid Token Holding Account for Vault B"); - let token_core::TokenHolding::Fungible { - definition_id: _, - balance: vault_b_balance, - } = vault_b_token_holding - else { - panic!("Swap: AMM Program expects a valid Fungible Token Holding Account for Vault B"); - }; - - assert!( - vault_b_balance >= pool_def_data.reserve_b, - "Reserve for Token B exceeds vault balance" - ); + let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b); let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) = if token_in_id == pool_def_data.definition_token_a_id { @@ -95,23 +139,18 @@ pub fn swap( panic!("AccountId is not a token type for the pool"); }; - // Update pool account - let mut pool_post = pool.account; - let pool_post_definition = PoolDefinition { - reserve_a: pool_def_data.reserve_a + deposit_a - withdraw_a, - reserve_b: pool_def_data.reserve_b + deposit_b - withdraw_b, - ..pool_def_data - }; - - pool_post.data = Data::from(&pool_post_definition); - - let post_states = vec![ - AccountPostState::new(pool_post), - AccountPostState::new(vault_a.account), - AccountPostState::new(vault_b.account), - AccountPostState::new(user_holding_a.account), - AccountPostState::new(user_holding_b.account), - ]; + let post_states = create_swap_post_states( + pool, + pool_def_data, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + deposit_a, + withdraw_a, + deposit_b, + withdraw_b, + ); (post_states, chained_calls) } @@ -131,7 +170,9 @@ fn swap_logic( // Compute withdraw amount // Maintains pool constant product // k = pool_def_data.reserve_a * pool_def_data.reserve_b; - let withdraw_amount = (reserve_withdraw_vault_amount * swap_amount_in) + let withdraw_amount = reserve_withdraw_vault_amount + .checked_mul(swap_amount_in) + .expect("reserve * amount_in overflows u128") / (reserve_deposit_vault_amount + swap_amount_in); // Slippage check @@ -175,3 +216,135 @@ fn swap_logic( (chained_calls, swap_amount_in, withdraw_amount) } + +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +#[must_use] +pub fn swap_exact_output( + pool: AccountWithMetadata, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + exact_amount_out: u128, + max_amount_in: u128, + token_in_id: AccountId, +) -> (Vec, Vec) { + let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b); + + let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) = + if token_in_id == pool_def_data.definition_token_a_id { + let (chained_calls, deposit_a, withdraw_b) = exact_output_swap_logic( + user_holding_a.clone(), + vault_a.clone(), + vault_b.clone(), + user_holding_b.clone(), + exact_amount_out, + max_amount_in, + pool_def_data.reserve_a, + pool_def_data.reserve_b, + pool.account_id, + ); + + (chained_calls, [deposit_a, 0], [0, withdraw_b]) + } else if token_in_id == pool_def_data.definition_token_b_id { + let (chained_calls, deposit_b, withdraw_a) = exact_output_swap_logic( + user_holding_b.clone(), + vault_b.clone(), + vault_a.clone(), + user_holding_a.clone(), + exact_amount_out, + max_amount_in, + pool_def_data.reserve_b, + pool_def_data.reserve_a, + pool.account_id, + ); + + (chained_calls, [0, withdraw_a], [deposit_b, 0]) + } else { + panic!("AccountId is not a token type for the pool"); + }; + + let post_states = create_swap_post_states( + pool, + pool_def_data, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + deposit_a, + withdraw_a, + deposit_b, + withdraw_b, + ); + + (post_states, chained_calls) +} + +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +fn exact_output_swap_logic( + user_deposit: AccountWithMetadata, + vault_deposit: AccountWithMetadata, + vault_withdraw: AccountWithMetadata, + user_withdraw: AccountWithMetadata, + exact_amount_out: u128, + max_amount_in: u128, + reserve_deposit_vault_amount: u128, + reserve_withdraw_vault_amount: u128, + pool_id: AccountId, +) -> (Vec, u128, u128) { + // Guard: exact_amount_out must be nonzero + assert_ne!(exact_amount_out, 0, "Exact amount out must be nonzero"); + + // Guard: exact_amount_out must be less than reserve_withdraw_vault_amount + assert!( + exact_amount_out < reserve_withdraw_vault_amount, + "Exact amount out exceeds reserve" + ); + + // Compute deposit amount using ceiling division + // Formula: amount_in = ceil(reserve_in * exact_amount_out / (reserve_out - exact_amount_out)) + let deposit_amount = reserve_deposit_vault_amount + .checked_mul(exact_amount_out) + .expect("reserve * amount_out overflows u128") + .div_ceil(reserve_withdraw_vault_amount - exact_amount_out); + + // Slippage check + assert!( + deposit_amount <= max_amount_in, + "Required input exceeds maximum amount in" + ); + + let token_program_id = user_deposit.account.program_owner; + + let mut chained_calls = Vec::new(); + chained_calls.push(ChainedCall::new( + token_program_id, + vec![user_deposit, vault_deposit], + &token_core::Instruction::Transfer { + amount_to_transfer: deposit_amount, + }, + )); + + let mut vault_withdraw = vault_withdraw; + vault_withdraw.is_authorized = true; + + let pda_seed = compute_vault_pda_seed( + pool_id, + token_core::TokenHolding::try_from(&vault_withdraw.account.data) + .expect("Exact Output Swap Logic: AMM Program expects valid token data") + .definition_id(), + ); + + chained_calls.push( + ChainedCall::new( + token_program_id, + vec![vault_withdraw, user_withdraw], + &token_core::Instruction::Transfer { + amount_to_transfer: exact_amount_out, + }, + ) + .with_pda_seeds(vec![pda_seed]), + ); + + (chained_calls, deposit_amount, exact_amount_out) +} diff --git a/programs/amm/src/tests.rs b/programs/amm/src/tests.rs index 14638f9d..3d9566f2 100644 --- a/programs/amm/src/tests.rs +++ b/programs/amm/src/tests.rs @@ -14,7 +14,10 @@ use nssa_core::{ use token_core::{TokenDefinition, TokenHolding}; use crate::{ - add::add_liquidity, new_definition::new_definition, remove::remove_liquidity, swap::swap, + add::add_liquidity, + new_definition::new_definition, + remove::remove_liquidity, + swap::{swap, swap_exact_output}, }; const TOKEN_PROGRAM_ID: ProgramId = [15; 8]; @@ -153,6 +156,10 @@ impl BalanceForTests { 200 } + fn max_amount_in() -> u128 { + 166 + } + fn vault_a_add_successful() -> u128 { 1_400 } @@ -243,6 +250,74 @@ impl ChainedCallForTests { ) } + fn cc_swap_exact_output_token_a_test_1() -> ChainedCall { + let swap_amount: u128 = 498; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::vault_a_init(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + } + + fn cc_swap_exact_output_token_b_test_1() -> ChainedCall { + let swap_amount: u128 = 166; + + let mut vault_b_auth = AccountWithMetadataForTests::vault_b_init(); + vault_b_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_b_auth, AccountWithMetadataForTests::user_holding_b()], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_b_definition_id(), + )]) + } + + fn cc_swap_exact_output_token_a_test_2() -> ChainedCall { + let swap_amount: u128 = 285; + + let mut vault_a_auth = AccountWithMetadataForTests::vault_a_init(); + vault_a_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_a_auth, AccountWithMetadataForTests::user_holding_a()], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_a_definition_id(), + )]) + } + + fn cc_swap_exact_output_token_b_test_2() -> ChainedCall { + let swap_amount: u128 = 200; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::vault_b_init(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + } + fn cc_add_token_a() -> ChainedCall { ChainedCall::new( TOKEN_PROGRAM_ID, @@ -829,6 +904,54 @@ impl AccountWithMetadataForTests { } } + fn pool_definition_swap_exact_output_test_1() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0_u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::lp_supply_init(), + reserve_a: 1498_u128, + reserve_b: 334_u128, + fees: 0_u128, + active: true, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + + fn pool_definition_swap_exact_output_test_2() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0_u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::lp_supply_init(), + reserve_a: BalanceForTests::vault_a_swap_test_2(), + reserve_b: BalanceForTests::vault_b_swap_test_2(), + fees: 0_u128, + active: true, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + fn pool_definition_add_zero_lp() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -2566,6 +2689,281 @@ fn call_swap_chained_call_successful_2() { ); } +#[should_panic(expected = "AccountId is not a token type for the pool")] +#[test] +fn call_swap_exact_output_incorrect_token_type() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_lp_definition_id(), + ); +} + +#[should_panic(expected = "Vault A was not provided")] +#[test] +fn call_swap_exact_output_vault_a_omitted() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_with_wrong_id(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Vault B was not provided")] +#[test] +fn call_swap_exact_output_vault_b_omitted() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_with_wrong_id(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Reserve for Token A exceeds vault balance")] +#[test] +fn call_swap_exact_output_reserves_vault_mismatch_1() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init_low(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Reserve for Token B exceeds vault balance")] +#[test] +fn call_swap_exact_output_reserves_vault_mismatch_2() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init_low(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Pool is inactive")] +#[test] +fn call_swap_exact_output_inactive() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_inactive(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Required input exceeds maximum amount in")] +#[test] +fn call_swap_exact_output_exceeds_max_in() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 166_u128, + 100_u128, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Exact amount out must be nonzero")] +#[test] +fn call_swap_exact_output_zero() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 0_u128, + 500_u128, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Exact amount out exceeds reserve")] +#[test] +fn call_swap_exact_output_exceeds_reserve() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::vault_b_reserve_init(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[test] +fn call_swap_exact_output_chained_call_successful() { + let (post_states, chained_calls) = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::max_amount_in(), + BalanceForTests::vault_b_reserve_init(), + IdForTests::token_a_definition_id(), + ); + + let pool_post = post_states[0].clone(); + + assert!( + AccountWithMetadataForTests::pool_definition_swap_exact_output_test_1().account + == *pool_post.account() + ); + + let chained_call_a = chained_calls[0].clone(); + let chained_call_b = chained_calls[1].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_swap_exact_output_token_a_test_1() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_swap_exact_output_token_b_test_1() + ); +} + +#[test] +fn call_swap_exact_output_chained_call_successful_2() { + let (post_states, chained_calls) = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 285, + 300, + IdForTests::token_b_definition_id(), + ); + + let pool_post = post_states[0].clone(); + + assert!( + AccountWithMetadataForTests::pool_definition_swap_exact_output_test_2().account + == *pool_post.account() + ); + + let chained_call_a = chained_calls[1].clone(); + let chained_call_b = chained_calls[0].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_swap_exact_output_token_a_test_2() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_swap_exact_output_token_b_test_2() + ); +} + +// Without the fix, `reserve_a * exact_amount_out` silently wraps to 0 in release mode, +// making `deposit_amount = 0`. The slippage check `0 <= max_amount_in` always passes, +// so an attacker receives `exact_amount_out` tokens while paying nothing. +#[should_panic(expected = "reserve * amount_out overflows u128")] +#[test] +fn swap_exact_output_overflow_protection() { + // reserve_a chosen so that reserve_a * 2 overflows u128: + // (u128::MAX / 2 + 1) * 2 = u128::MAX + 1 → wraps to 0 + let large_reserve: u128 = u128::MAX / 2 + 1; + let reserve_b: u128 = 1_000; + + let pool = AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: 1, + reserve_a: large_reserve, + reserve_b, + fees: 0, + active: true, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + }; + + let vault_a = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance: large_reserve, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::vault_a_id(), + }; + + let vault_b = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance: reserve_b, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::vault_b_id(), + }; + + let _result = swap_exact_output( + pool, + vault_a, + vault_b, + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 2, // exact_amount_out: small, valid (< reserve_b) + 1, // max_amount_in: tiny — real deposit would be enormous, but + // overflow wraps it to 0, making 0 <= 1 pass silently + IdForTests::token_a_definition_id(), + ); +} + #[test] fn new_definition_lp_asymmetric_amounts() { let (post_states, chained_calls) = new_definition( diff --git a/wallet/src/cli/programs/amm.rs b/wallet/src/cli/programs/amm.rs index 7307569d..be87cfcc 100644 --- a/wallet/src/cli/programs/amm.rs +++ b/wallet/src/cli/programs/amm.rs @@ -52,6 +52,26 @@ pub enum AmmProgramAgnosticSubcommand { #[arg(long)] token_definition: String, }, + /// Swap specifying exact output amount. + /// + /// The account associated with swapping token must be owned. + /// + /// Only public execution allowed. + SwapExactOutput { + /// `user_holding_a` - valid 32 byte base58 string with privacy prefix. + #[arg(long)] + user_holding_a: String, + /// `user_holding_b` - valid 32 byte base58 string with privacy prefix. + #[arg(long)] + user_holding_b: String, + #[arg(long)] + exact_amount_out: u128, + #[arg(long)] + max_amount_in: u128, + /// `token_definition` - valid 32 byte base58 string WITHOUT privacy prefix. + #[arg(long)] + token_definition: String, + }, /// Add liquidity. /// /// `user_holding_a` and `user_holding_b` must be owned. @@ -185,6 +205,41 @@ impl WalletSubcommand for AmmProgramAgnosticSubcommand { } } } + Self::SwapExactOutput { + user_holding_a, + user_holding_b, + exact_amount_out, + max_amount_in, + token_definition, + } => { + let (user_holding_a, user_holding_a_privacy) = + parse_addr_with_privacy_prefix(&user_holding_a)?; + let (user_holding_b, user_holding_b_privacy) = + parse_addr_with_privacy_prefix(&user_holding_b)?; + + let user_holding_a: AccountId = user_holding_a.parse()?; + let user_holding_b: AccountId = user_holding_b.parse()?; + + match (user_holding_a_privacy, user_holding_b_privacy) { + (AccountPrivacyKind::Public, AccountPrivacyKind::Public) => { + Amm(wallet_core) + .send_swap_exact_output( + user_holding_a, + user_holding_b, + exact_amount_out, + max_amount_in, + token_definition.parse()?, + ) + .await?; + + Ok(SubcommandReturnValue::Empty) + } + _ => { + // ToDo: Implement after private multi-chain calls is available + anyhow::bail!("Only public execution allowed for Amm calls"); + } + } + } Self::AddLiquidity { user_holding_a, user_holding_b, diff --git a/wallet/src/program_facades/amm.rs b/wallet/src/program_facades/amm.rs index d68de7a5..d32558d6 100644 --- a/wallet/src/program_facades/amm.rs +++ b/wallet/src/program_facades/amm.rs @@ -168,34 +168,105 @@ impl Amm<'_> { user_holding_b, ]; - let account_id_auth; + let account_id_auth = if definition_token_a_id == token_definition_id_in { + user_holding_a + } else if definition_token_b_id == token_definition_id_in { + user_holding_b + } else { + return Err(ExecutionFailureKind::AccountDataError( + token_definition_id_in, + )); + }; - // Checking, which account are associated with TokenDefinition - let token_holder_acc_a = self + let nonces = self + .0 + .get_accounts_nonces(vec![account_id_auth]) + .await + .map_err(ExecutionFailureKind::SequencerError)?; + + let signing_key = self + .0 + .storage + .user_data + .get_pub_account_signing_key(account_id_auth) + .ok_or(ExecutionFailureKind::KeyNotFoundError)?; + + let message = nssa::public_transaction::Message::try_new( + program.id(), + account_ids, + nonces, + instruction, + ) + .unwrap(); + + let witness_set = + nssa::public_transaction::WitnessSet::for_message(&message, &[signing_key]); + + let tx = nssa::PublicTransaction::new(message, witness_set); + + Ok(self + .0 + .sequencer_client + .send_transaction(NSSATransaction::Public(tx)) + .await?) + } + + pub async fn send_swap_exact_output( + &self, + user_holding_a: AccountId, + user_holding_b: AccountId, + exact_amount_out: u128, + max_amount_in: u128, + token_definition_id_in: AccountId, + ) -> Result { + let instruction = amm_core::Instruction::SwapExactOutput { + exact_amount_out, + max_amount_in, + token_definition_id_in, + }; + let program = Program::amm(); + let amm_program_id = Program::amm().id(); + + let user_a_acc = self .0 .get_account_public(user_holding_a) .await .map_err(ExecutionFailureKind::SequencerError)?; - let token_holder_acc_b = self + let user_b_acc = self .0 .get_account_public(user_holding_b) .await .map_err(ExecutionFailureKind::SequencerError)?; - let token_holder_a = TokenHolding::try_from(&token_holder_acc_a.data) - .map_err(|_err| ExecutionFailureKind::AccountDataError(user_holding_a))?; - let token_holder_b = TokenHolding::try_from(&token_holder_acc_b.data) - .map_err(|_err| ExecutionFailureKind::AccountDataError(user_holding_b))?; + let definition_token_a_id = TokenHolding::try_from(&user_a_acc.data) + .map_err(|_err| ExecutionFailureKind::AccountDataError(user_holding_a))? + .definition_id(); + let definition_token_b_id = TokenHolding::try_from(&user_b_acc.data) + .map_err(|_err| ExecutionFailureKind::AccountDataError(user_holding_b))? + .definition_id(); - if token_holder_a.definition_id() == token_definition_id_in { - account_id_auth = user_holding_a; - } else if token_holder_b.definition_id() == token_definition_id_in { - account_id_auth = user_holding_b; + let amm_pool = + compute_pool_pda(amm_program_id, definition_token_a_id, definition_token_b_id); + let vault_holding_a = compute_vault_pda(amm_program_id, amm_pool, definition_token_a_id); + let vault_holding_b = compute_vault_pda(amm_program_id, amm_pool, definition_token_b_id); + + let account_ids = vec![ + amm_pool, + vault_holding_a, + vault_holding_b, + user_holding_a, + user_holding_b, + ]; + + let account_id_auth = if definition_token_a_id == token_definition_id_in { + user_holding_a + } else if definition_token_b_id == token_definition_id_in { + user_holding_b } else { return Err(ExecutionFailureKind::AccountDataError( token_definition_id_in, )); - } + }; let nonces = self .0