diff --git a/amm/src/remove.rs b/amm/src/remove.rs index ab1f7d8..a5a1042 100644 --- a/amm/src/remove.rs +++ b/amm/src/remove.rs @@ -60,7 +60,41 @@ pub fn remove_liquidity( "Minimum withdraw amount must be nonzero" ); - // 2. Compute withdrawal amounts + // 2. Read live vault balances and compute withdrawal amounts + let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data) + .expect("Remove liquidity: AMM Program expects a valid Token Holding Account for Vault A"); + let token_core::TokenHolding::Fungible { + definition_id: _, + balance: vault_a_balance, + } = vault_a_token_holding + else { + panic!( + "Remove liquidity: AMM Program expects a valid Fungible Token Holding Account for Vault A" + ); + }; + + assert!( + vault_a_balance >= pool_def_data.reserve_a, + "Reserve for Token A exceeds vault balance" + ); + + let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data) + .expect("Remove liquidity: AMM Program expects a valid Token Holding Account for Vault B"); + let token_core::TokenHolding::Fungible { + definition_id: _, + balance: vault_b_balance, + } = vault_b_token_holding + else { + panic!( + "Remove liquidity: AMM Program expects a valid Fungible Token Holding Account for Vault B" + ); + }; + + assert!( + vault_b_balance >= pool_def_data.reserve_b, + "Reserve for Token B exceeds vault balance" + ); + let user_holding_lp_data = token_core::TokenHolding::try_from(&user_holding_lp.account.data) .expect("Remove liquidity: AMM Program expects a valid Token Account for liquidity token"); let token_core::TokenHolding::Fungible { @@ -96,46 +130,58 @@ pub fn remove_liquidity( "Cannot remove locked minimum liquidity" ); - let withdraw_amount_a = pool_def_data + // Reserve accounting stays anchored to tracked reserves, while user withdrawals use the + // live vault balances so donated surplus is paid out proportionally. + let reserve_withdraw_amount_a = pool_def_data .reserve_a .checked_mul(remove_liquidity_amount) .expect("reserve_a * remove_liquidity_amount overflows u128") / pool_def_data.liquidity_pool_supply; - let withdraw_amount_b = pool_def_data + let reserve_withdraw_amount_b = pool_def_data .reserve_b .checked_mul(remove_liquidity_amount) .expect("reserve_b * remove_liquidity_amount overflows u128") / pool_def_data.liquidity_pool_supply; + let actual_withdraw_amount_a = vault_a_balance + .checked_mul(remove_liquidity_amount) + .expect("vault_a_balance * remove_liquidity_amount overflows u128") + / pool_def_data.liquidity_pool_supply; + let actual_withdraw_amount_b = vault_b_balance + .checked_mul(remove_liquidity_amount) + .expect("vault_b_balance * remove_liquidity_amount overflows u128") + / pool_def_data.liquidity_pool_supply; // 3. Validate and slippage check assert!( - withdraw_amount_a >= min_amount_to_remove_token_a, + actual_withdraw_amount_a >= min_amount_to_remove_token_a, "Insufficient minimal withdraw amount (Token A) provided for liquidity amount" ); assert!( - withdraw_amount_b >= min_amount_to_remove_token_b, + actual_withdraw_amount_b >= min_amount_to_remove_token_b, "Insufficient minimal withdraw amount (Token B) provided for liquidity amount" ); - // 4. Calculate LP to reduce cap by - let delta_lp: u128 = remove_liquidity_amount; + // 4. Burn exactly the requested LP amount. + let burn_amount_lp = remove_liquidity_amount; + let remaining_liquidity = pool_def_data + .liquidity_pool_supply + .checked_sub(burn_amount_lp) + .expect("liquidity_pool_supply - burn_amount_lp underflows"); + let active = remaining_liquidity != 0; // 5. Update pool account let mut pool_post = pool.account.clone(); let pool_post_definition = PoolDefinition { - liquidity_pool_supply: pool_def_data - .liquidity_pool_supply - .checked_sub(delta_lp) - .expect("liquidity_pool_supply - delta_lp underflows"), + liquidity_pool_supply: remaining_liquidity, reserve_a: pool_def_data .reserve_a - .checked_sub(withdraw_amount_a) - .expect("reserve_a - withdraw_amount_a underflows"), + .checked_sub(reserve_withdraw_amount_a) + .expect("reserve_a - reserve_withdraw_amount_a underflows"), reserve_b: pool_def_data .reserve_b - .checked_sub(withdraw_amount_b) - .expect("reserve_b - withdraw_amount_b underflows"), - active: true, + .checked_sub(reserve_withdraw_amount_b) + .expect("reserve_b - reserve_withdraw_amount_b underflows"), + active, ..pool_def_data.clone() }; @@ -148,7 +194,7 @@ pub fn remove_liquidity( token_program_id, vec![running_vault_a, user_holding_a.clone()], &token_core::Instruction::Transfer { - amount_to_transfer: withdraw_amount_a, + amount_to_transfer: actual_withdraw_amount_a, }, ) .with_pda_seeds(vec![compute_vault_pda_seed( @@ -160,7 +206,7 @@ pub fn remove_liquidity( token_program_id, vec![running_vault_b, user_holding_b.clone()], &token_core::Instruction::Transfer { - amount_to_transfer: withdraw_amount_b, + amount_to_transfer: actual_withdraw_amount_b, }, ) .with_pda_seeds(vec![compute_vault_pda_seed( @@ -174,7 +220,7 @@ pub fn remove_liquidity( token_program_id, vec![pool_definition_lp_auth, user_holding_lp.clone()], &token_core::Instruction::Burn { - amount_to_burn: delta_lp, + amount_to_burn: burn_amount_lp, }, ) .with_pda_seeds(vec![compute_liquidity_token_pda_seed(pool.account_id)]); diff --git a/amm/src/tests.rs b/amm/src/tests.rs index 70b8434..2dc4789 100644 --- a/amm/src/tests.rs +++ b/amm/src/tests.rs @@ -208,6 +208,28 @@ impl BalanceForTests { fn remove_lp_supply_successful() -> u128 { BalanceForTests::lp_supply_init() - BalanceForTests::remove_amount_lp() } + + fn vault_a_balance_with_surplus() -> u128 { + BalanceForTests::vault_a_reserve_init() + (BalanceForTests::vault_a_reserve_init() / 10) + } + + fn vault_b_balance_with_surplus() -> u128 { + BalanceForTests::vault_b_reserve_init() + (BalanceForTests::vault_b_reserve_init() / 10) + } + + fn remove_actual_a_with_surplus() -> u128 { + (BalanceForTests::vault_a_balance_with_surplus() * BalanceForTests::remove_amount_lp()) + / BalanceForTests::lp_supply_init() + } + + fn remove_actual_b_with_surplus() -> u128 { + (BalanceForTests::vault_b_balance_with_surplus() * BalanceForTests::remove_amount_lp()) + / BalanceForTests::lp_supply_init() + } + + fn remove_min_amount_b_surplus() -> u128 { + BalanceForTests::remove_actual_b_with_surplus() + } } impl ChainedCallForTests { @@ -422,6 +444,40 @@ impl ChainedCallForTests { )]) } + fn cc_remove_token_a_with_surplus() -> ChainedCall { + let mut vault_a_auth = AccountWithMetadataForTests::vault_a_with_surplus(); + vault_a_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_a_auth, AccountWithMetadataForTests::user_holding_a()], + &token_core::Instruction::Transfer { + amount_to_transfer: BalanceForTests::remove_actual_a_with_surplus(), + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_a_definition_id(), + )]) + } + + fn cc_remove_token_b_with_surplus() -> ChainedCall { + let mut vault_b_auth = AccountWithMetadataForTests::vault_b_with_surplus(); + vault_b_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_b_auth, AccountWithMetadataForTests::user_holding_b()], + &token_core::Instruction::Transfer { + amount_to_transfer: BalanceForTests::remove_actual_b_with_surplus(), + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_b_definition_id(), + )]) + } + fn cc_remove_pool_lp() -> ChainedCall { let mut pool_lp_auth = AccountWithMetadataForTests::pool_lp_init(); pool_lp_auth.is_authorized = true; @@ -619,6 +675,38 @@ impl AccountWithMetadataForTests { } } + fn vault_a_with_surplus() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance: BalanceForTests::vault_a_balance_with_surplus(), + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::vault_a_id(), + } + } + + fn vault_b_with_surplus() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance: BalanceForTests::vault_b_balance_with_surplus(), + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::vault_b_id(), + } + } + fn vault_a_init_high() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -1761,6 +1849,40 @@ fn test_call_remove_liquidity_min_bal_zero_2() { ); } +#[should_panic(expected = "Reserve for Token A exceeds vault balance")] +#[test] +fn test_call_remove_liquidity_reserves_vault_mismatch_1() { + let _post_states = remove_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init_low(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(), + BalanceForTests::remove_min_amount_a(), + BalanceForTests::remove_min_amount_b(), + ); +} + +#[should_panic(expected = "Reserve for Token B exceeds vault balance")] +#[test] +fn test_call_remove_liquidity_reserves_vault_mismatch_2() { + let _post_states = remove_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init_low(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(), + BalanceForTests::remove_min_amount_a(), + BalanceForTests::remove_min_amount_b(), + ); +} + #[test] fn test_call_remove_liquidity_chained_call_successful() { let (post_states, chained_calls) = remove_liquidity( @@ -1792,6 +1914,37 @@ fn test_call_remove_liquidity_chained_call_successful() { assert!(chained_call_lp == ChainedCallForTests::cc_remove_pool_lp()); } +#[test] +fn test_call_remove_liquidity_chained_call_with_vault_surplus_successful() { + let (post_states, chained_calls) = remove_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_with_surplus(), + AccountWithMetadataForTests::vault_b_with_surplus(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(), + BalanceForTests::remove_min_amount_a(), + BalanceForTests::remove_min_amount_b_surplus(), + ); + + let pool_post = post_states[0].clone(); + + assert!( + AccountWithMetadataForTests::pool_definition_remove_successful().account + == *pool_post.account() + ); + + let chained_call_lp = chained_calls[0].clone(); + let chained_call_b = chained_calls[1].clone(); + let chained_call_a = chained_calls[2].clone(); + + assert!(chained_call_a == ChainedCallForTests::cc_remove_token_a_with_surplus()); + assert!(chained_call_b == ChainedCallForTests::cc_remove_token_b_with_surplus()); + assert!(chained_call_lp == ChainedCallForTests::cc_remove_pool_lp()); +} + #[should_panic(expected = "Balances must be nonzero")] #[test] fn test_call_new_definition_with_zero_balance_1() { diff --git a/integration_tests/tests/amm.rs b/integration_tests/tests/amm.rs index b36da68..bd5c7de 100644 --- a/integration_tests/tests/amm.rs +++ b/integration_tests/tests/amm.rs @@ -144,6 +144,14 @@ impl Balances { 500 } + fn remove_min_a_with_surplus() -> u128 { + 1_100 + } + + fn remove_min_b_with_surplus() -> u128 { + 550 + } + fn add_min_lp() -> u128 { 1_000 } @@ -228,6 +236,22 @@ impl Balances { 2_000 } + fn vault_a_with_surplus() -> u128 { + 5_500 + } + + fn vault_b_with_surplus() -> u128 { + 2_750 + } + + fn vault_a_remove_with_surplus() -> u128 { + 4_400 + } + + fn vault_b_remove_with_surplus() -> u128 { + 2_200 + } + fn user_a_remove() -> u128 { 11_000 } @@ -236,6 +260,14 @@ impl Balances { 10_500 } + fn user_a_remove_with_surplus() -> u128 { + 11_100 + } + + fn user_b_remove_with_surplus() -> u128 { + 10_550 + } + fn user_lp_remove() -> u128 { 1_000 } @@ -369,6 +401,30 @@ impl Accounts { } } + fn vault_a_with_surplus() -> Account { + Account { + program_owner: Ids::token_program(), + balance: 0_u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: Ids::token_a_definition(), + balance: Balances::vault_a_with_surplus(), + }), + nonce: Nonce(0), + } + } + + fn vault_b_with_surplus() -> Account { + Account { + program_owner: Ids::token_program(), + balance: 0_u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: Ids::token_b_definition(), + balance: Balances::vault_b_with_surplus(), + }), + nonce: Nonce(0), + } + } + fn user_lp_holding() -> Account { Account { program_owner: Ids::token_program(), @@ -668,6 +724,30 @@ impl Accounts { } } + fn vault_a_remove_with_surplus() -> Account { + Account { + program_owner: Ids::token_program(), + balance: 0_u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: Ids::token_a_definition(), + balance: Balances::vault_a_remove_with_surplus(), + }), + nonce: Nonce(0), + } + } + + fn vault_b_remove_with_surplus() -> Account { + Account { + program_owner: Ids::token_program(), + balance: 0_u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: Ids::token_b_definition(), + balance: Balances::vault_b_remove_with_surplus(), + }), + nonce: Nonce(0), + } + } + fn user_a_holding_remove() -> Account { Account { program_owner: Ids::token_program(), @@ -692,6 +772,30 @@ impl Accounts { } } + fn user_a_holding_remove_with_surplus() -> Account { + Account { + program_owner: Ids::token_program(), + balance: 0_u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: Ids::token_a_definition(), + balance: Balances::user_a_remove_with_surplus(), + }), + nonce: Nonce(0), + } + } + + fn user_b_holding_remove_with_surplus() -> Account { + Account { + program_owner: Ids::token_program(), + balance: 0_u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: Ids::token_b_definition(), + balance: Balances::user_b_remove_with_surplus(), + }), + nonce: Nonce(0), + } + } + fn user_lp_holding_remove() -> Account { Account { program_owner: Ids::token_program(), @@ -1055,6 +1159,69 @@ fn amm_remove_liquidity_insufficient_user_lp_fails() { assert!(state.transition_from_public_transaction(&tx, 0).is_err()); } +#[test] +fn amm_remove_liquidity_with_surplus() { + let mut state = state_for_amm_tests(); + state.force_insert_account(Ids::vault_a(), Accounts::vault_a_with_surplus()); + state.force_insert_account(Ids::vault_b(), Accounts::vault_b_with_surplus()); + + let instruction = amm_core::Instruction::RemoveLiquidity { + remove_liquidity_amount: Balances::remove_lp(), + min_amount_to_remove_token_a: Balances::remove_min_a_with_surplus(), + min_amount_to_remove_token_b: Balances::remove_min_b_with_surplus(), + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::token_lp_definition(), + Ids::user_a(), + Ids::user_b(), + Ids::user_lp(), + ], + vec![Nonce(0)], + instruction, + ) + .unwrap(); + + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_lp()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); + + assert_eq!( + state.get_account_by_id(Ids::pool_definition()), + Accounts::pool_definition_remove() + ); + assert_eq!( + state.get_account_by_id(Ids::vault_a()), + Accounts::vault_a_remove_with_surplus() + ); + assert_eq!( + state.get_account_by_id(Ids::vault_b()), + Accounts::vault_b_remove_with_surplus() + ); + assert_eq!( + state.get_account_by_id(Ids::token_lp_definition()), + Accounts::token_lp_definition_remove() + ); + assert_eq!( + state.get_account_by_id(Ids::user_a()), + Accounts::user_a_holding_remove_with_surplus() + ); + assert_eq!( + state.get_account_by_id(Ids::user_b()), + Accounts::user_b_holding_remove_with_surplus() + ); + assert_eq!( + state.get_account_by_id(Ids::user_lp()), + Accounts::user_lp_holding_remove() + ); +} + #[test] fn amm_new_definition_inactive_initialized_pool_and_uninit_user_lp() { let mut state = state_for_amm_tests_with_new_def();