From 22978733d9297cf5e64ba4a83bc7098d63de796f Mon Sep 17 00:00:00 2001 From: Ricardo Guilherme Schmidt <3esmit@gmail.com> Date: Tue, 31 Mar 2026 23:15:10 -0300 Subject: [PATCH] feat(amm): apply trading fees to LP accounting - charge configured swap fees through fee-adjusted reserve updates - preserve accrued LP fees when adding new liquidity - add AMM coverage for fee accrual, fee payout, and anti-dilution --- amm/src/add.rs | 21 ++- amm/src/swap.rs | 52 ++++-- amm/src/tests.rs | 319 ++++++++++++++++++++++++++++----- integration_tests/tests/amm.rs | 302 ++++++++++++++++++++++++++++++- 4 files changed, 622 insertions(+), 72 deletions(-) diff --git a/amm/src/add.rs b/amm/src/add.rs index 91f5962..a2b8881 100644 --- a/amm/src/add.rs +++ b/amm/src/add.rs @@ -80,17 +80,16 @@ pub fn add_liquidity( "Vaults' balances must be at least the reserve amounts" ); - // Calculate actual_amounts - let ideal_a: u128 = pool_def_data - .reserve_a + // Quote deposits against live vault balances so newly added LPs do not + // receive a share of previously accrued fee surplus. + let ideal_a = vault_a_balance .checked_mul(max_amount_to_add_token_b) - .expect("reserve_a * max_amount_b overflows u128") - / pool_def_data.reserve_b; - let ideal_b: u128 = pool_def_data - .reserve_b + .expect("vault_a_balance * max_amount_to_add_token_b overflows u128") + / vault_b_balance; + let ideal_b = vault_b_balance .checked_mul(max_amount_to_add_token_a) - .expect("reserve_b * max_amount_a overflows u128") - / pool_def_data.reserve_a; + .expect("vault_b_balance * max_amount_to_add_token_a overflows u128") + / vault_a_balance; let actual_amount_a = if ideal_a > max_amount_to_add_token_a { max_amount_to_add_token_a @@ -122,12 +121,12 @@ pub fn add_liquidity( .liquidity_pool_supply .checked_mul(actual_amount_a) .expect("liquidity_pool_supply * actual_amount_a overflows u128") - / pool_def_data.reserve_a, + / vault_a_balance, pool_def_data .liquidity_pool_supply .checked_mul(actual_amount_b) .expect("liquidity_pool_supply * actual_amount_b overflows u128") - / pool_def_data.reserve_b, + / vault_b_balance, ); assert!(delta_lp != 0, "Payable LP must be nonzero"); diff --git a/amm/src/swap.rs b/amm/src/swap.rs index fe5c544..27cd41e 100644 --- a/amm/src/swap.rs +++ b/amm/src/swap.rs @@ -1,4 +1,4 @@ -use amm_core::assert_supported_fee_tier; +use amm_core::{assert_supported_fee_tier, FEE_BPS_DENOMINATOR}; pub use amm_core::{compute_liquidity_token_pda_seed, compute_vault_pda_seed, PoolDefinition}; use nssa_core::{ account::{AccountId, AccountWithMetadata, Data}, @@ -127,6 +127,7 @@ pub fn swap_exact_input( user_holding_b.clone(), swap_amount_in, min_amount_out, + pool_def_data.fees, pool_def_data.reserve_a, pool_def_data.reserve_b, pool.account_id, @@ -141,6 +142,7 @@ pub fn swap_exact_input( user_holding_a.clone(), swap_amount_in, min_amount_out, + pool_def_data.fees, pool_def_data.reserve_b, pool_def_data.reserve_a, pool.account_id, @@ -175,19 +177,31 @@ fn swap_logic( user_withdraw: AccountWithMetadata, swap_amount_in: u128, min_amount_out: u128, + fee_bps: u128, reserve_deposit_vault_amount: u128, reserve_withdraw_vault_amount: u128, pool_id: AccountId, ) -> (Vec, u128, u128) { - // Compute withdraw amount - // Maintains pool constant product - // k = pool_def_data.reserve_a * pool_def_data.reserve_b; + let fee_multiplier = FEE_BPS_DENOMINATOR + .checked_sub(fee_bps) + .expect("fee_bps exceeds fee denominator"); + let effective_amount_in = swap_amount_in + .checked_mul(fee_multiplier) + .expect("swap_amount_in * fee_multiplier overflows u128") + / FEE_BPS_DENOMINATOR; + assert!( + effective_amount_in != 0, + "Effective swap amount should be nonzero" + ); + + // Compute withdraw amount from fee-adjusted reserves while leaving the fee + // portion behind as vault surplus for LPs. let withdraw_amount = reserve_withdraw_vault_amount - .checked_mul(swap_amount_in) - .expect("reserve * amount_in overflows u128") + .checked_mul(effective_amount_in) + .expect("reserve_withdraw_vault_amount * effective_amount_in overflows u128") / reserve_deposit_vault_amount - .checked_add(swap_amount_in) - .expect("reserve + swap_amount_in overflows u128"); + .checked_add(effective_amount_in) + .expect("reserve_deposit_vault_amount + effective_amount_in overflows u128"); // Slippage check assert!( @@ -228,7 +242,7 @@ fn swap_logic( .with_pda_seeds(vec![pda_seed]), ); - (chained_calls, swap_amount_in, withdraw_amount) + (chained_calls, effective_amount_in, withdraw_amount) } #[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] @@ -254,6 +268,7 @@ pub fn swap_exact_output( user_holding_b.clone(), exact_amount_out, max_amount_in, + pool_def_data.fees, pool_def_data.reserve_a, pool_def_data.reserve_b, pool.account_id, @@ -268,6 +283,7 @@ pub fn swap_exact_output( user_holding_a.clone(), exact_amount_out, max_amount_in, + pool_def_data.fees, pool_def_data.reserve_b, pool_def_data.reserve_a, pool.account_id, @@ -302,6 +318,7 @@ fn exact_output_swap_logic( user_withdraw: AccountWithMetadata, exact_amount_out: u128, max_amount_in: u128, + fee_bps: u128, reserve_deposit_vault_amount: u128, reserve_withdraw_vault_amount: u128, pool_id: AccountId, @@ -317,10 +334,21 @@ fn exact_output_swap_logic( // Compute deposit amount using ceiling division // Formula: amount_in = ceil(reserve_in * exact_amount_out / (reserve_out - exact_amount_out)) - let deposit_amount = reserve_deposit_vault_amount + let effective_deposit_amount = reserve_deposit_vault_amount .checked_mul(exact_amount_out) .expect("reserve * amount_out overflows u128") - .div_ceil(reserve_withdraw_vault_amount - exact_amount_out); + .div_ceil( + reserve_withdraw_vault_amount + .checked_sub(exact_amount_out) + .expect("reserve_withdraw_vault_amount - exact_amount_out underflows"), + ); + let fee_multiplier = FEE_BPS_DENOMINATOR + .checked_sub(fee_bps) + .expect("fee_bps exceeds fee denominator"); + let deposit_amount = effective_deposit_amount + .checked_mul(FEE_BPS_DENOMINATOR) + .expect("effective_deposit_amount * fee denominator overflows u128") + .div_ceil(fee_multiplier); // Slippage check assert!( @@ -360,5 +388,5 @@ fn exact_output_swap_logic( .with_pda_seeds(vec![pda_seed]), ); - (chained_calls, deposit_amount, exact_amount_out) + (chained_calls, effective_deposit_amount, exact_amount_out) } diff --git a/amm/src/tests.rs b/amm/src/tests.rs index 2dc4789..afcb918 100644 --- a/amm/src/tests.rs +++ b/amm/src/tests.rs @@ -5,7 +5,8 @@ use std::num::NonZero; use amm_core::{ compute_liquidity_token_pda, compute_liquidity_token_pda_seed, compute_lp_lock_holding_pda, compute_pool_pda, compute_vault_pda, compute_vault_pda_seed, PoolDefinition, FEE_TIER_BPS_1, - FEE_TIER_BPS_100, FEE_TIER_BPS_30, FEE_TIER_BPS_5, MINIMUM_LIQUIDITY, + FEE_TIER_BPS_100, FEE_TIER_BPS_30, FEE_TIER_BPS_5, FEE_BPS_DENOMINATOR, + MINIMUM_LIQUIDITY, }; use nssa_core::{ account::{Account, AccountId, AccountWithMetadata, Data, Nonce}, @@ -124,8 +125,20 @@ impl BalanceForTests { BalanceForTests::lp_supply_init() - MINIMUM_LIQUIDITY } + fn effective_swap_amount_in_a() -> u128 { + BalanceForTests::add_max_amount_a() + * (FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier()) + / FEE_BPS_DENOMINATOR + } + + fn effective_swap_amount_in_b() -> u128 { + BalanceForTests::add_max_amount_b() + * (FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier()) + / FEE_BPS_DENOMINATOR + } + fn vault_a_swap_test_1() -> u128 { - BalanceForTests::vault_a_reserve_init() + BalanceForTests::add_max_amount_a() + BalanceForTests::vault_a_reserve_init() + BalanceForTests::effective_swap_amount_in_a() } fn vault_a_swap_test_2() -> u128 { @@ -137,7 +150,7 @@ impl BalanceForTests { } fn vault_b_swap_test_2() -> u128 { - BalanceForTests::vault_b_reserve_init() + BalanceForTests::add_max_amount_b() + BalanceForTests::vault_b_reserve_init() + BalanceForTests::effective_swap_amount_in_b() } fn min_amount_out() -> u128 { @@ -169,6 +182,72 @@ impl BalanceForTests { 166 } + fn vault_a_balance_with_surplus() -> u128 { + BalanceForTests::vault_a_reserve_init() + (BalanceForTests::vault_a_reserve_init() / 10) + } + + fn vault_b_balance_with_surplus() -> u128 { + BalanceForTests::vault_b_reserve_init() + (BalanceForTests::vault_b_reserve_init() / 10) + } + + fn exact_output_effective_amount_in_token_a() -> u128 { + BalanceForTests::vault_a_reserve_init() + .checked_mul(BalanceForTests::max_amount_in()) + .expect("vault_a_reserve_init * max_amount_in overflows u128") + .div_ceil(BalanceForTests::vault_b_reserve_init() - BalanceForTests::max_amount_in()) + } + + fn exact_output_effective_amount_in_token_b() -> u128 { + BalanceForTests::vault_b_reserve_init() + .checked_mul(285) + .expect("vault_b_reserve_init * exact_amount_out overflows u128") + .div_ceil(BalanceForTests::vault_a_reserve_init() - 285) + } + + fn exact_output_deposit_amount_token_a() -> u128 { + BalanceForTests::exact_output_effective_amount_in_token_a() + .checked_mul(FEE_BPS_DENOMINATOR) + .expect("effective amount in * fee denominator overflows u128") + .div_ceil(FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier()) + } + + fn exact_output_deposit_amount_token_b() -> u128 { + BalanceForTests::exact_output_effective_amount_in_token_b() + .checked_mul(FEE_BPS_DENOMINATOR) + .expect("effective amount in * fee denominator overflows u128") + .div_ceil(FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier()) + } + + fn reserve_a_add_with_surplus() -> u128 { + BalanceForTests::vault_a_reserve_init() + BalanceForTests::add_successful_amount_a_with_surplus() + } + + fn reserve_b_add_with_surplus() -> u128 { + BalanceForTests::vault_b_reserve_init() + BalanceForTests::add_successful_amount_b_with_surplus() + } + + fn add_successful_amount_a_with_surplus() -> u128 { + (BalanceForTests::vault_a_balance_with_surplus() * BalanceForTests::add_max_amount_b()) + / BalanceForTests::vault_b_balance_with_surplus() + } + + fn add_successful_amount_b_with_surplus() -> u128 { + BalanceForTests::add_max_amount_b() + } + + fn lp_supply_with_surplus() -> u128 { + BalanceForTests::lp_supply_init() + BalanceForTests::lp_mint_with_surplus() + } + + fn lp_mint_with_surplus() -> u128 { + std::cmp::min( + BalanceForTests::lp_supply_init() * BalanceForTests::add_successful_amount_a_with_surplus() + / BalanceForTests::vault_a_balance_with_surplus(), + BalanceForTests::lp_supply_init() * BalanceForTests::add_successful_amount_b_with_surplus() + / BalanceForTests::vault_b_balance_with_surplus(), + ) + } + fn vault_a_remove_successful() -> u128 { BalanceForTests::vault_a_reserve_init() - BalanceForTests::remove_actual_a_successful() } @@ -178,13 +257,15 @@ impl BalanceForTests { } fn swap_amount_out_b() -> u128 { - (BalanceForTests::vault_b_reserve_init() * BalanceForTests::add_max_amount_a()) - / (BalanceForTests::vault_a_reserve_init() + BalanceForTests::add_max_amount_a()) + (BalanceForTests::vault_b_reserve_init() * BalanceForTests::effective_swap_amount_in_a()) + / (BalanceForTests::vault_a_reserve_init() + + BalanceForTests::effective_swap_amount_in_a()) } fn swap_amount_out_a() -> u128 { - (BalanceForTests::vault_a_reserve_init() * BalanceForTests::add_max_amount_b()) - / (BalanceForTests::vault_b_reserve_init() + BalanceForTests::add_max_amount_b()) + (BalanceForTests::vault_a_reserve_init() * BalanceForTests::effective_swap_amount_in_b()) + / (BalanceForTests::vault_b_reserve_init() + + BalanceForTests::effective_swap_amount_in_b()) } fn add_delta_lp_successful() -> u128 { @@ -209,14 +290,6 @@ impl BalanceForTests { BalanceForTests::lp_supply_init() - BalanceForTests::remove_amount_lp() } - fn vault_a_balance_with_surplus() -> u128 { - BalanceForTests::vault_a_reserve_init() + (BalanceForTests::vault_a_reserve_init() / 10) - } - - fn vault_b_balance_with_surplus() -> u128 { - BalanceForTests::vault_b_reserve_init() + (BalanceForTests::vault_b_reserve_init() / 10) - } - fn remove_actual_a_with_surplus() -> u128 { (BalanceForTests::vault_a_balance_with_surplus() * BalanceForTests::remove_amount_lp()) / BalanceForTests::lp_supply_init() @@ -298,7 +371,7 @@ impl ChainedCallForTests { } fn cc_swap_exact_output_token_a_test_1() -> ChainedCall { - let swap_amount: u128 = 498; + let swap_amount = BalanceForTests::exact_output_deposit_amount_token_a(); ChainedCall::new( TOKEN_PROGRAM_ID, @@ -351,7 +424,7 @@ impl ChainedCallForTests { } fn cc_swap_exact_output_token_b_test_2() -> ChainedCall { - let swap_amount: u128 = 200; + let swap_amount = BalanceForTests::exact_output_deposit_amount_token_b(); ChainedCall::new( TOKEN_PROGRAM_ID, @@ -410,6 +483,51 @@ impl ChainedCallForTests { )]) } + fn cc_add_token_a_with_surplus() -> ChainedCall { + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::vault_a_with_surplus(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: BalanceForTests::add_successful_amount_a_with_surplus(), + }, + ) + } + + fn cc_add_token_b_with_surplus() -> ChainedCall { + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::vault_b_with_surplus(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: BalanceForTests::add_successful_amount_b_with_surplus(), + }, + ) + } + + fn cc_add_pool_lp_with_surplus() -> ChainedCall { + let mut pool_lp_auth = AccountWithMetadataForTests::pool_lp_init(); + pool_lp_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + pool_lp_auth, + AccountWithMetadataForTests::user_holding_lp_init(), + ], + &token_core::Instruction::Mint { + amount_to_mint: BalanceForTests::lp_mint_with_surplus(), + }, + ) + .with_pda_seeds(vec![compute_liquidity_token_pda_seed( + IdForTests::pool_definition_id(), + )]) + } + fn cc_remove_token_a() -> ChainedCall { let mut vault_a_auth = AccountWithMetadataForTests::vault_a_init(); vault_a_auth.is_authorized = true; @@ -444,6 +562,25 @@ impl ChainedCallForTests { )]) } + fn cc_remove_pool_lp() -> ChainedCall { + let mut pool_lp_auth = AccountWithMetadataForTests::pool_lp_init(); + pool_lp_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + pool_lp_auth, + AccountWithMetadataForTests::user_holding_lp_init(), + ], + &token_core::Instruction::Burn { + amount_to_burn: BalanceForTests::remove_amount_lp(), + }, + ) + .with_pda_seeds(vec![compute_liquidity_token_pda_seed( + IdForTests::pool_definition_id(), + )]) + } + fn cc_remove_token_a_with_surplus() -> ChainedCall { let mut vault_a_auth = AccountWithMetadataForTests::vault_a_with_surplus(); vault_a_auth.is_authorized = true; @@ -478,25 +615,6 @@ impl ChainedCallForTests { )]) } - fn cc_remove_pool_lp() -> ChainedCall { - let mut pool_lp_auth = AccountWithMetadataForTests::pool_lp_init(); - pool_lp_auth.is_authorized = true; - - ChainedCall::new( - TOKEN_PROGRAM_ID, - vec![ - pool_lp_auth, - AccountWithMetadataForTests::user_holding_lp_init(), - ], - &token_core::Instruction::Burn { - amount_to_burn: BalanceForTests::remove_amount_lp(), - }, - ) - .with_pda_seeds(vec![compute_liquidity_token_pda_seed( - IdForTests::pool_definition_id(), - )]) - } - fn cc_new_definition_token_a() -> ChainedCall { ChainedCall::new( TOKEN_PROGRAM_ID, @@ -1156,8 +1274,10 @@ impl AccountWithMetadataForTests { vault_b_id: IdForTests::vault_b_id(), liquidity_pool_id: IdForTests::token_lp_definition_id(), liquidity_pool_supply: BalanceForTests::lp_supply_init(), - reserve_a: 1498_u128, - reserve_b: 334_u128, + reserve_a: BalanceForTests::vault_a_reserve_init() + + BalanceForTests::exact_output_effective_amount_in_token_a(), + reserve_b: BalanceForTests::vault_b_reserve_init() + - BalanceForTests::max_amount_in(), fees: BalanceForTests::fee_tier(), active: true, }), @@ -1180,8 +1300,9 @@ impl AccountWithMetadataForTests { vault_b_id: IdForTests::vault_b_id(), liquidity_pool_id: IdForTests::token_lp_definition_id(), liquidity_pool_supply: BalanceForTests::lp_supply_init(), - reserve_a: 715_u128, - reserve_b: 700_u128, + reserve_a: BalanceForTests::vault_a_reserve_init() - 285, + reserve_b: BalanceForTests::vault_b_reserve_init() + + BalanceForTests::exact_output_effective_amount_in_token_b(), fees: BalanceForTests::fee_tier(), active: true, }), @@ -1240,6 +1361,54 @@ impl AccountWithMetadataForTests { } } + fn pool_definition_add_with_surplus_successful() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::lp_supply_with_surplus(), + reserve_a: BalanceForTests::reserve_a_add_with_surplus(), + reserve_b: BalanceForTests::reserve_b_add_with_surplus(), + fees: BalanceForTests::fee_tier(), + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + + fn pool_definition_init_low_balances() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::vault_a_reserve_low(), + reserve_a: BalanceForTests::vault_a_reserve_low(), + reserve_b: BalanceForTests::vault_b_reserve_low(), + fees: BalanceForTests::fee_tier(), + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + fn pool_definition_remove_successful() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -1669,6 +1838,46 @@ fn test_call_add_liquidity_chained_call_successsful() { assert!(chained_call_lp == ChainedCallForTests::cc_add_pool_lp()); } +#[test] +fn test_call_add_liquidity_with_fee_surplus_preserves_existing_lp_value() { + let (post_states, chained_calls) = add_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_with_surplus(), + AccountWithMetadataForTests::vault_b_with_surplus(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::add_max_amount_b(), + ); + + let pool_post = post_states[0].clone(); + + assert!( + AccountWithMetadataForTests::pool_definition_add_with_surplus_successful().account + == *pool_post.account() + ); + + let chained_call_lp = chained_calls[0].clone(); + let chained_call_b = chained_calls[1].clone(); + let chained_call_a = chained_calls[2].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_add_token_a_with_surplus() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_add_token_b_with_surplus() + ); + assert_eq!( + chained_call_lp, + ChainedCallForTests::cc_add_pool_lp_with_surplus() + ); +} + #[should_panic(expected = "Vault A was not provided")] #[test] fn test_call_remove_liquidity_vault_a_omitted() { @@ -2293,6 +2502,36 @@ fn test_call_swap_below_min_out() { ); } +#[should_panic(expected = "Effective swap amount should be nonzero")] +#[test] +fn test_call_swap_effective_amount_zero() { + let _post_states = swap_exact_input( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 1, + 0, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Withdraw amount should be nonzero")] +#[test] +fn test_call_swap_output_rounds_to_zero() { + let _post_states = swap_exact_input( + AccountWithMetadataForTests::pool_definition_init_low_balances(), + AccountWithMetadataForTests::vault_a_init_low(), + AccountWithMetadataForTests::vault_b_init_low(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 2, + 0, + IdForTests::token_a_definition_id(), + ); +} + #[test] fn test_call_swap_chained_call_successful_1() { let (post_states, chained_calls) = swap_exact_input( diff --git a/integration_tests/tests/amm.rs b/integration_tests/tests/amm.rs index bd5c7de..cb68885 100644 --- a/integration_tests/tests/amm.rs +++ b/integration_tests/tests/amm.rs @@ -172,8 +172,16 @@ impl Balances { 200 } + fn reserve_a_swap_1() -> u128 { + 3_575 + } + + fn reserve_b_swap_1() -> u128 { + 3_497 + } + fn vault_a_swap_1() -> u128 { - 3_572 + 3_575 } fn vault_b_swap_1() -> u128 { @@ -181,19 +189,27 @@ impl Balances { } fn user_a_swap_1() -> u128 { - 11_428 + 11_425 } fn user_b_swap_1() -> u128 { 9_000 } + fn reserve_a_swap_2() -> u128 { + 5_997 + } + + fn reserve_b_swap_2() -> u128 { + 2_085 + } + fn vault_a_swap_2() -> u128 { 6_000 } fn vault_b_swap_2() -> u128 { - 2_084 + 2_085 } fn user_a_swap_2() -> u128 { @@ -201,7 +217,7 @@ impl Balances { } fn user_b_swap_2() -> u128 { - 10_416 + 10_415 } fn vault_a_add() -> u128 { @@ -462,8 +478,8 @@ impl Accounts { vault_b_id: Ids::vault_b(), liquidity_pool_id: Ids::token_lp_definition(), liquidity_pool_supply: Balances::pool_lp_supply_init(), - reserve_a: Balances::vault_a_swap_1(), - reserve_b: Balances::vault_b_swap_1(), + reserve_a: Balances::reserve_a_swap_1(), + reserve_b: Balances::reserve_b_swap_1(), fees: Balances::fee_tier(), active: true, }), @@ -530,8 +546,8 @@ impl Accounts { vault_b_id: Ids::vault_b(), liquidity_pool_id: Ids::token_lp_definition(), liquidity_pool_supply: Balances::pool_lp_supply_init(), - reserve_a: Balances::vault_a_swap_2(), - reserve_b: Balances::vault_b_swap_2(), + reserve_a: Balances::reserve_a_swap_2(), + reserve_b: Balances::reserve_b_swap_2(), fees: Balances::fee_tier(), active: true, }), @@ -1029,6 +1045,11 @@ fn state_for_amm_tests_with_new_def() -> V03State { state } +fn try_execute_new_definition(state: &mut V03State, fees: u128) -> Result<(), NssaError> { +fn current_nonce(state: &V03State, account_id: AccountId) -> Nonce { + state.get_account_by_id(account_id).nonce +} + fn try_execute_new_definition(state: &mut V03State, fees: u128) -> Result<(), NssaError> { let instruction = amm_core::Instruction::NewDefinition { token_a_amount: Balances::vault_a_init(), @@ -1049,7 +1070,10 @@ fn try_execute_new_definition(state: &mut V03State, fees: u128) -> Result<(), Ns Ids::user_b(), Ids::user_lp(), ], - vec![Nonce(0), Nonce(0)], + vec![ + current_nonce(state, Ids::user_a()), + current_nonce(state, Ids::user_b()), + ], instruction, ) .unwrap(); @@ -1065,6 +1089,163 @@ fn execute_new_definition(state: &mut V03State, fees: u128) { try_execute_new_definition(state, fees).unwrap(); } +fn execute_swap_a_to_b(state: &mut V03State, swap_amount_in: u128, min_amount_out: u128) { + let instruction = amm_core::Instruction::SwapExactInput { + swap_amount_in, + min_amount_out, + token_definition_id_in: Ids::token_a_definition(), + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::user_a(), + Ids::user_b(), + ], + vec![current_nonce(state, Ids::user_a())], + instruction, + ) + .unwrap(); + + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_a()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); +} + +fn execute_swap_b_to_a(state: &mut V03State, swap_amount_in: u128, min_amount_out: u128) { + let instruction = amm_core::Instruction::SwapExactInput { + swap_amount_in, + min_amount_out, + token_definition_id_in: Ids::token_b_definition(), + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::user_a(), + Ids::user_b(), + ], + vec![current_nonce(state, Ids::user_b())], + instruction, + ) + .unwrap(); + + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_b()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); +} + +fn execute_add_liquidity( + state: &mut V03State, + min_amount_liquidity: u128, + max_amount_to_add_token_a: u128, + max_amount_to_add_token_b: u128, +) { + let instruction = amm_core::Instruction::AddLiquidity { + min_amount_liquidity, + max_amount_to_add_token_a, + max_amount_to_add_token_b, + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::token_lp_definition(), + Ids::user_a(), + Ids::user_b(), + Ids::user_lp(), + ], + vec![ + current_nonce(state, Ids::user_a()), + current_nonce(state, Ids::user_b()), + ], + instruction, + ) + .unwrap(); + + let witness_set = + public_transaction::WitnessSet::for_message(&message, &[&Keys::user_a(), &Keys::user_b()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); +} + +fn execute_remove_liquidity( + state: &mut V03State, + remove_liquidity_amount: u128, + min_amount_to_remove_token_a: u128, + min_amount_to_remove_token_b: u128, +) { + let instruction = amm_core::Instruction::RemoveLiquidity { + remove_liquidity_amount, + min_amount_to_remove_token_a, + min_amount_to_remove_token_b, + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::token_lp_definition(), + Ids::user_a(), + Ids::user_b(), + Ids::user_lp(), + ], + vec![current_nonce(state, Ids::user_lp())], + instruction, + ) + .unwrap(); + + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_lp()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); +} + +fn fungible_balance(account: &Account) -> u128 { + let holding = TokenHolding::try_from(&account.data).expect("expected token holding"); + let TokenHolding::Fungible { + definition_id: _, + balance, + } = holding + else { + panic!("expected fungible token holding") + }; + + balance +} + +fn pool_definition(account: &Account) -> PoolDefinition { + PoolDefinition::try_from(&account.data).expect("expected pool definition") +} + +fn fungible_total_supply(account: &Account) -> u128 { + let definition = TokenDefinition::try_from(&account.data).expect("expected token definition"); + let TokenDefinition::Fungible { + name: _, + total_supply, + metadata_id: _, + } = definition + else { + panic!("expected fungible token definition") + }; + + total_supply +} + #[test] fn amm_remove_liquidity() { let mut state = state_for_amm_tests(); @@ -1588,3 +1769,106 @@ fn amm_swap_a_to_b() { Accounts::user_b_holding_swap_2() ); } + +#[test] +fn amm_fee_accumulates_across_multiple_swaps_and_pays_out_on_remove() { + let mut state = state_for_amm_tests(); + + execute_swap_a_to_b(&mut state, 1_000, 200); + execute_swap_b_to_a(&mut state, 1_000, 200); + + let pool_before_remove = pool_definition(&state.get_account_by_id(Ids::pool_definition())); + assert_eq!(pool_before_remove.reserve_a, 4_058); + assert_eq!(pool_before_remove.reserve_b, 3_082); + assert_eq!(pool_before_remove.fees, Balances::fee_tier()); + + let vault_a_before_remove = fungible_balance(&state.get_account_by_id(Ids::vault_a())); + let vault_b_before_remove = fungible_balance(&state.get_account_by_id(Ids::vault_b())); + assert_eq!(vault_a_before_remove, 4_061); + assert_eq!(vault_b_before_remove, 3_085); + assert_eq!(vault_a_before_remove - pool_before_remove.reserve_a, 3); + assert_eq!(vault_b_before_remove - pool_before_remove.reserve_b, 3); + + execute_remove_liquidity(&mut state, 1_000, 812, 617); + + let pool_after_remove = pool_definition(&state.get_account_by_id(Ids::pool_definition())); + assert_eq!(pool_after_remove.reserve_a, 3_247); + assert_eq!(pool_after_remove.reserve_b, 2_466); + assert_eq!(pool_after_remove.liquidity_pool_supply, 4_000); + + let vault_a_after_remove = fungible_balance(&state.get_account_by_id(Ids::vault_a())); + let vault_b_after_remove = fungible_balance(&state.get_account_by_id(Ids::vault_b())); + assert_eq!(vault_a_after_remove, 3_249); + assert_eq!(vault_b_after_remove, 2_468); + assert_eq!(vault_a_after_remove - pool_after_remove.reserve_a, 2); + assert_eq!(vault_b_after_remove - pool_after_remove.reserve_b, 2); + + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_a())), + 11_751 + ); + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_b())), + 10_032 + ); + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_lp())), + 1_000 + ); + assert_eq!( + fungible_total_supply(&state.get_account_by_id(Ids::token_lp_definition())), + 4_000 + ); +} + +#[test] +fn amm_add_liquidity_after_fee_accrual_preserves_surplus() { + let mut state = state_for_amm_tests(); + + execute_swap_a_to_b(&mut state, 1_000, 200); + execute_swap_b_to_a(&mut state, 1_000, 200); + execute_swap_a_to_b(&mut state, 1_000, 200); + execute_swap_b_to_a(&mut state, 1_000, 200); + + let pool_before_add = pool_definition(&state.get_account_by_id(Ids::pool_definition())); + let vault_a_before_add = fungible_balance(&state.get_account_by_id(Ids::vault_a())); + let vault_b_before_add = fungible_balance(&state.get_account_by_id(Ids::vault_b())); + + assert_eq!(pool_before_add.reserve_a, 3_604); + assert_eq!(pool_before_add.reserve_b, 3_472); + assert_eq!(vault_a_before_add, 3_610); + assert_eq!(vault_b_before_add, 3_478); + assert_eq!(vault_a_before_add - pool_before_add.reserve_a, 6); + assert_eq!(vault_b_before_add - pool_before_add.reserve_b, 6); + + execute_add_liquidity(&mut state, 1_436, 2_000, 1_000); + + let pool_after_add = pool_definition(&state.get_account_by_id(Ids::pool_definition())); + let vault_a_after_add = fungible_balance(&state.get_account_by_id(Ids::vault_a())); + let vault_b_after_add = fungible_balance(&state.get_account_by_id(Ids::vault_b())); + + assert_eq!(pool_after_add.reserve_a, 4_641); + assert_eq!(pool_after_add.reserve_b, 4_472); + assert_eq!(pool_after_add.liquidity_pool_supply, 6_436); + assert_eq!(vault_a_after_add, 4_647); + assert_eq!(vault_b_after_add, 4_478); + assert_eq!(vault_a_after_add - pool_after_add.reserve_a, 6); + assert_eq!(vault_b_after_add - pool_after_add.reserve_b, 6); + + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_a())), + 10_353 + ); + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_b())), + 8_022 + ); + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_lp())), + 3_436 + ); + assert_eq!( + fungible_total_supply(&state.get_account_by_id(Ids::token_lp_definition())), + 6_436 + ); +}