feat: enhance remove liquidity function to handle surplus balances

This commit is contained in:
Ricardo Guilherme Schmidt 2026-03-31 18:39:12 -03:00
parent 9824cd8f90
commit de23a2a4d4
No known key found for this signature in database
GPG Key ID: 1396EA17DE132FFE
3 changed files with 385 additions and 19 deletions

View File

@ -60,7 +60,41 @@ pub fn remove_liquidity(
"Minimum withdraw amount must be nonzero"
);
// 2. Compute withdrawal amounts
// 2. Read live vault balances and compute withdrawal amounts
let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data)
.expect("Remove liquidity: AMM Program expects a valid Token Holding Account for Vault A");
let token_core::TokenHolding::Fungible {
definition_id: _,
balance: vault_a_balance,
} = vault_a_token_holding
else {
panic!(
"Remove liquidity: AMM Program expects a valid Fungible Token Holding Account for Vault A"
);
};
assert!(
vault_a_balance >= pool_def_data.reserve_a,
"Reserve for Token A exceeds vault balance"
);
let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data)
.expect("Remove liquidity: AMM Program expects a valid Token Holding Account for Vault B");
let token_core::TokenHolding::Fungible {
definition_id: _,
balance: vault_b_balance,
} = vault_b_token_holding
else {
panic!(
"Remove liquidity: AMM Program expects a valid Fungible Token Holding Account for Vault B"
);
};
assert!(
vault_b_balance >= pool_def_data.reserve_b,
"Reserve for Token B exceeds vault balance"
);
let user_holding_lp_data = token_core::TokenHolding::try_from(&user_holding_lp.account.data)
.expect("Remove liquidity: AMM Program expects a valid Token Account for liquidity token");
let token_core::TokenHolding::Fungible {
@ -96,46 +130,58 @@ pub fn remove_liquidity(
"Cannot remove locked minimum liquidity"
);
let withdraw_amount_a = pool_def_data
// Reserve accounting stays anchored to tracked reserves, while user withdrawals use the
// live vault balances so donated surplus is paid out proportionally.
let reserve_withdraw_amount_a = pool_def_data
.reserve_a
.checked_mul(remove_liquidity_amount)
.expect("reserve_a * remove_liquidity_amount overflows u128")
/ pool_def_data.liquidity_pool_supply;
let withdraw_amount_b = pool_def_data
let reserve_withdraw_amount_b = pool_def_data
.reserve_b
.checked_mul(remove_liquidity_amount)
.expect("reserve_b * remove_liquidity_amount overflows u128")
/ pool_def_data.liquidity_pool_supply;
let actual_withdraw_amount_a = vault_a_balance
.checked_mul(remove_liquidity_amount)
.expect("vault_a_balance * remove_liquidity_amount overflows u128")
/ pool_def_data.liquidity_pool_supply;
let actual_withdraw_amount_b = vault_b_balance
.checked_mul(remove_liquidity_amount)
.expect("vault_b_balance * remove_liquidity_amount overflows u128")
/ pool_def_data.liquidity_pool_supply;
// 3. Validate and slippage check
assert!(
withdraw_amount_a >= min_amount_to_remove_token_a,
actual_withdraw_amount_a >= min_amount_to_remove_token_a,
"Insufficient minimal withdraw amount (Token A) provided for liquidity amount"
);
assert!(
withdraw_amount_b >= min_amount_to_remove_token_b,
actual_withdraw_amount_b >= min_amount_to_remove_token_b,
"Insufficient minimal withdraw amount (Token B) provided for liquidity amount"
);
// 4. Calculate LP to reduce cap by
let delta_lp: u128 = remove_liquidity_amount;
// 4. Burn exactly the requested LP amount.
let burn_amount_lp = remove_liquidity_amount;
let remaining_liquidity = pool_def_data
.liquidity_pool_supply
.checked_sub(burn_amount_lp)
.expect("liquidity_pool_supply - burn_amount_lp underflows");
let active = remaining_liquidity != 0;
// 5. Update pool account
let mut pool_post = pool.account.clone();
let pool_post_definition = PoolDefinition {
liquidity_pool_supply: pool_def_data
.liquidity_pool_supply
.checked_sub(delta_lp)
.expect("liquidity_pool_supply - delta_lp underflows"),
liquidity_pool_supply: remaining_liquidity,
reserve_a: pool_def_data
.reserve_a
.checked_sub(withdraw_amount_a)
.expect("reserve_a - withdraw_amount_a underflows"),
.checked_sub(reserve_withdraw_amount_a)
.expect("reserve_a - reserve_withdraw_amount_a underflows"),
reserve_b: pool_def_data
.reserve_b
.checked_sub(withdraw_amount_b)
.expect("reserve_b - withdraw_amount_b underflows"),
active: true,
.checked_sub(reserve_withdraw_amount_b)
.expect("reserve_b - reserve_withdraw_amount_b underflows"),
active,
..pool_def_data.clone()
};
@ -148,7 +194,7 @@ pub fn remove_liquidity(
token_program_id,
vec![running_vault_a, user_holding_a.clone()],
&token_core::Instruction::Transfer {
amount_to_transfer: withdraw_amount_a,
amount_to_transfer: actual_withdraw_amount_a,
},
)
.with_pda_seeds(vec![compute_vault_pda_seed(
@ -160,7 +206,7 @@ pub fn remove_liquidity(
token_program_id,
vec![running_vault_b, user_holding_b.clone()],
&token_core::Instruction::Transfer {
amount_to_transfer: withdraw_amount_b,
amount_to_transfer: actual_withdraw_amount_b,
},
)
.with_pda_seeds(vec![compute_vault_pda_seed(
@ -174,7 +220,7 @@ pub fn remove_liquidity(
token_program_id,
vec![pool_definition_lp_auth, user_holding_lp.clone()],
&token_core::Instruction::Burn {
amount_to_burn: delta_lp,
amount_to_burn: burn_amount_lp,
},
)
.with_pda_seeds(vec![compute_liquidity_token_pda_seed(pool.account_id)]);

View File

@ -208,6 +208,28 @@ impl BalanceForTests {
fn remove_lp_supply_successful() -> u128 {
BalanceForTests::lp_supply_init() - BalanceForTests::remove_amount_lp()
}
fn vault_a_balance_with_surplus() -> u128 {
BalanceForTests::vault_a_reserve_init() + (BalanceForTests::vault_a_reserve_init() / 10)
}
fn vault_b_balance_with_surplus() -> u128 {
BalanceForTests::vault_b_reserve_init() + (BalanceForTests::vault_b_reserve_init() / 10)
}
fn remove_actual_a_with_surplus() -> u128 {
(BalanceForTests::vault_a_balance_with_surplus() * BalanceForTests::remove_amount_lp())
/ BalanceForTests::lp_supply_init()
}
fn remove_actual_b_with_surplus() -> u128 {
(BalanceForTests::vault_b_balance_with_surplus() * BalanceForTests::remove_amount_lp())
/ BalanceForTests::lp_supply_init()
}
fn remove_min_amount_b_surplus() -> u128 {
BalanceForTests::remove_actual_b_with_surplus()
}
}
impl ChainedCallForTests {
@ -422,6 +444,40 @@ impl ChainedCallForTests {
)])
}
fn cc_remove_token_a_with_surplus() -> ChainedCall {
let mut vault_a_auth = AccountWithMetadataForTests::vault_a_with_surplus();
vault_a_auth.is_authorized = true;
ChainedCall::new(
TOKEN_PROGRAM_ID,
vec![vault_a_auth, AccountWithMetadataForTests::user_holding_a()],
&token_core::Instruction::Transfer {
amount_to_transfer: BalanceForTests::remove_actual_a_with_surplus(),
},
)
.with_pda_seeds(vec![compute_vault_pda_seed(
IdForTests::pool_definition_id(),
IdForTests::token_a_definition_id(),
)])
}
fn cc_remove_token_b_with_surplus() -> ChainedCall {
let mut vault_b_auth = AccountWithMetadataForTests::vault_b_with_surplus();
vault_b_auth.is_authorized = true;
ChainedCall::new(
TOKEN_PROGRAM_ID,
vec![vault_b_auth, AccountWithMetadataForTests::user_holding_b()],
&token_core::Instruction::Transfer {
amount_to_transfer: BalanceForTests::remove_actual_b_with_surplus(),
},
)
.with_pda_seeds(vec![compute_vault_pda_seed(
IdForTests::pool_definition_id(),
IdForTests::token_b_definition_id(),
)])
}
fn cc_remove_pool_lp() -> ChainedCall {
let mut pool_lp_auth = AccountWithMetadataForTests::pool_lp_init();
pool_lp_auth.is_authorized = true;
@ -619,6 +675,38 @@ impl AccountWithMetadataForTests {
}
}
fn vault_a_with_surplus() -> AccountWithMetadata {
AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_a_definition_id(),
balance: BalanceForTests::vault_a_balance_with_surplus(),
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::vault_a_id(),
}
}
fn vault_b_with_surplus() -> AccountWithMetadata {
AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_b_definition_id(),
balance: BalanceForTests::vault_b_balance_with_surplus(),
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::vault_b_id(),
}
}
fn vault_a_init_high() -> AccountWithMetadata {
AccountWithMetadata {
account: Account {
@ -1761,6 +1849,40 @@ fn test_call_remove_liquidity_min_bal_zero_2() {
);
}
#[should_panic(expected = "Reserve for Token A exceeds vault balance")]
#[test]
fn test_call_remove_liquidity_reserves_vault_mismatch_1() {
let _post_states = remove_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init_low(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(),
BalanceForTests::remove_min_amount_a(),
BalanceForTests::remove_min_amount_b(),
);
}
#[should_panic(expected = "Reserve for Token B exceeds vault balance")]
#[test]
fn test_call_remove_liquidity_reserves_vault_mismatch_2() {
let _post_states = remove_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init_low(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(),
BalanceForTests::remove_min_amount_a(),
BalanceForTests::remove_min_amount_b(),
);
}
#[test]
fn test_call_remove_liquidity_chained_call_successful() {
let (post_states, chained_calls) = remove_liquidity(
@ -1792,6 +1914,37 @@ fn test_call_remove_liquidity_chained_call_successful() {
assert!(chained_call_lp == ChainedCallForTests::cc_remove_pool_lp());
}
#[test]
fn test_call_remove_liquidity_chained_call_with_vault_surplus_successful() {
let (post_states, chained_calls) = remove_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_with_surplus(),
AccountWithMetadataForTests::vault_b_with_surplus(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(),
BalanceForTests::remove_min_amount_a(),
BalanceForTests::remove_min_amount_b_surplus(),
);
let pool_post = post_states[0].clone();
assert!(
AccountWithMetadataForTests::pool_definition_remove_successful().account
== *pool_post.account()
);
let chained_call_lp = chained_calls[0].clone();
let chained_call_b = chained_calls[1].clone();
let chained_call_a = chained_calls[2].clone();
assert!(chained_call_a == ChainedCallForTests::cc_remove_token_a_with_surplus());
assert!(chained_call_b == ChainedCallForTests::cc_remove_token_b_with_surplus());
assert!(chained_call_lp == ChainedCallForTests::cc_remove_pool_lp());
}
#[should_panic(expected = "Balances must be nonzero")]
#[test]
fn test_call_new_definition_with_zero_balance_1() {

View File

@ -144,6 +144,14 @@ impl Balances {
500
}
fn remove_min_a_with_surplus() -> u128 {
1_100
}
fn remove_min_b_with_surplus() -> u128 {
550
}
fn add_min_lp() -> u128 {
1_000
}
@ -228,6 +236,22 @@ impl Balances {
2_000
}
fn vault_a_with_surplus() -> u128 {
5_500
}
fn vault_b_with_surplus() -> u128 {
2_750
}
fn vault_a_remove_with_surplus() -> u128 {
4_400
}
fn vault_b_remove_with_surplus() -> u128 {
2_200
}
fn user_a_remove() -> u128 {
11_000
}
@ -236,6 +260,14 @@ impl Balances {
10_500
}
fn user_a_remove_with_surplus() -> u128 {
11_100
}
fn user_b_remove_with_surplus() -> u128 {
10_550
}
fn user_lp_remove() -> u128 {
1_000
}
@ -369,6 +401,30 @@ impl Accounts {
}
}
fn vault_a_with_surplus() -> Account {
Account {
program_owner: Ids::token_program(),
balance: 0_u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: Ids::token_a_definition(),
balance: Balances::vault_a_with_surplus(),
}),
nonce: Nonce(0),
}
}
fn vault_b_with_surplus() -> Account {
Account {
program_owner: Ids::token_program(),
balance: 0_u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: Ids::token_b_definition(),
balance: Balances::vault_b_with_surplus(),
}),
nonce: Nonce(0),
}
}
fn user_lp_holding() -> Account {
Account {
program_owner: Ids::token_program(),
@ -668,6 +724,30 @@ impl Accounts {
}
}
fn vault_a_remove_with_surplus() -> Account {
Account {
program_owner: Ids::token_program(),
balance: 0_u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: Ids::token_a_definition(),
balance: Balances::vault_a_remove_with_surplus(),
}),
nonce: Nonce(0),
}
}
fn vault_b_remove_with_surplus() -> Account {
Account {
program_owner: Ids::token_program(),
balance: 0_u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: Ids::token_b_definition(),
balance: Balances::vault_b_remove_with_surplus(),
}),
nonce: Nonce(0),
}
}
fn user_a_holding_remove() -> Account {
Account {
program_owner: Ids::token_program(),
@ -692,6 +772,30 @@ impl Accounts {
}
}
fn user_a_holding_remove_with_surplus() -> Account {
Account {
program_owner: Ids::token_program(),
balance: 0_u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: Ids::token_a_definition(),
balance: Balances::user_a_remove_with_surplus(),
}),
nonce: Nonce(0),
}
}
fn user_b_holding_remove_with_surplus() -> Account {
Account {
program_owner: Ids::token_program(),
balance: 0_u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: Ids::token_b_definition(),
balance: Balances::user_b_remove_with_surplus(),
}),
nonce: Nonce(0),
}
}
fn user_lp_holding_remove() -> Account {
Account {
program_owner: Ids::token_program(),
@ -1055,6 +1159,69 @@ fn amm_remove_liquidity_insufficient_user_lp_fails() {
assert!(state.transition_from_public_transaction(&tx, 0).is_err());
}
#[test]
fn amm_remove_liquidity_with_surplus() {
let mut state = state_for_amm_tests();
state.force_insert_account(Ids::vault_a(), Accounts::vault_a_with_surplus());
state.force_insert_account(Ids::vault_b(), Accounts::vault_b_with_surplus());
let instruction = amm_core::Instruction::RemoveLiquidity {
remove_liquidity_amount: Balances::remove_lp(),
min_amount_to_remove_token_a: Balances::remove_min_a_with_surplus(),
min_amount_to_remove_token_b: Balances::remove_min_b_with_surplus(),
};
let message = public_transaction::Message::try_new(
Ids::amm_program(),
vec![
Ids::pool_definition(),
Ids::vault_a(),
Ids::vault_b(),
Ids::token_lp_definition(),
Ids::user_a(),
Ids::user_b(),
Ids::user_lp(),
],
vec![Nonce(0)],
instruction,
)
.unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_lp()]);
let tx = PublicTransaction::new(message, witness_set);
state.transition_from_public_transaction(&tx, 0).unwrap();
assert_eq!(
state.get_account_by_id(Ids::pool_definition()),
Accounts::pool_definition_remove()
);
assert_eq!(
state.get_account_by_id(Ids::vault_a()),
Accounts::vault_a_remove_with_surplus()
);
assert_eq!(
state.get_account_by_id(Ids::vault_b()),
Accounts::vault_b_remove_with_surplus()
);
assert_eq!(
state.get_account_by_id(Ids::token_lp_definition()),
Accounts::token_lp_definition_remove()
);
assert_eq!(
state.get_account_by_id(Ids::user_a()),
Accounts::user_a_holding_remove_with_surplus()
);
assert_eq!(
state.get_account_by_id(Ids::user_b()),
Accounts::user_b_holding_remove_with_surplus()
);
assert_eq!(
state.get_account_by_id(Ids::user_lp()),
Accounts::user_lp_holding_remove()
);
}
#[test]
fn amm_new_definition_inactive_initialized_pool_and_uninit_user_lp() {
let mut state = state_for_amm_tests_with_new_def();