feat(amm): apply trading fees to LP accounting

Implement Uniswap V2-style fees-in-reserves: the full swap_amount_in is
deposited into the reserve (growing k = reserve_a * reserve_b), while
only the fee-adjusted effective_amount_in is used to compute the output
amount. This means LPs earn fees proportionally on every removal via
k-growth rather than through a separate vault surplus.

- swap_logic: add fee_bps parameter; compute effective_amount_in for
  output formula only; return full swap_amount_in as the reserve deposit
- Fix all integration test fixture values to match fees-in-reserves math
- Remove dead-code vault_a/b_init_zero helpers from unit tests
This commit is contained in:
Ricardo Guilherme Schmidt 2026-03-31 23:15:10 -03:00 committed by r4bbit
parent 1f8eea8442
commit 7199d594e9
No known key found for this signature in database
GPG Key ID: E95F1E9447DC91A9
4 changed files with 663 additions and 87 deletions

View File

@ -1,6 +1,9 @@
use std::num::NonZeroU128; use std::num::NonZeroU128;
use amm_core::{assert_supported_fee_tier, compute_liquidity_token_pda_seed, PoolDefinition}; use amm_core::{
assert_supported_fee_tier, compute_liquidity_token_pda_seed, read_vault_fungible_balances,
PoolDefinition,
};
use nssa_core::{ use nssa_core::{
account::{AccountWithMetadata, Data}, account::{AccountWithMetadata, Data},
program::{AccountPostState, ChainedCall}, program::{AccountPostState, ChainedCall},
@ -44,33 +47,9 @@ pub fn add_liquidity(
"Both max-balances must be nonzero" "Both max-balances must be nonzero"
); );
// 2. Determine deposit amount let (vault_a_balance, vault_b_balance) =
let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data) read_vault_fungible_balances("Add liquidity", &vault_a, &vault_b);
.expect("Add liquidity: AMM Program expects valid Token Holding Account for Vault B");
let token_core::TokenHolding::Fungible {
definition_id: _,
balance: vault_b_balance,
} = vault_b_token_holding
else {
panic!(
"Add liquidity: AMM Program expects valid Fungible Token Holding Account for Vault B"
);
};
let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data)
.expect("Add liquidity: AMM Program expects valid Token Holding Account for Vault A");
let token_core::TokenHolding::Fungible {
definition_id: _,
balance: vault_a_balance,
} = vault_a_token_holding
else {
panic!(
"Add liquidity: AMM Program expects valid Fungible Token Holding Account for Vault A"
);
};
assert!(pool_def_data.reserve_a != 0, "Reserves must be nonzero");
assert!(pool_def_data.reserve_b != 0, "Reserves must be nonzero");
assert!( assert!(
vault_a_balance >= pool_def_data.reserve_a, vault_a_balance >= pool_def_data.reserve_a,
"Vaults' balances must be at least the reserve amounts" "Vaults' balances must be at least the reserve amounts"
@ -80,7 +59,10 @@ pub fn add_liquidity(
"Vaults' balances must be at least the reserve amounts" "Vaults' balances must be at least the reserve amounts"
); );
// Calculate actual_amounts // 2. Determine deposit amount
assert!(pool_def_data.reserve_a != 0, "Reserves must be nonzero");
assert!(pool_def_data.reserve_b != 0, "Reserves must be nonzero");
let ideal_a: u128 = pool_def_data let ideal_a: u128 = pool_def_data
.reserve_a .reserve_a
.checked_mul(max_amount_to_add_token_b) .checked_mul(max_amount_to_add_token_b)

View File

@ -1,4 +1,6 @@
use amm_core::{assert_supported_fee_tier, MINIMUM_LIQUIDITY}; use amm_core::{
assert_supported_fee_tier, read_vault_fungible_balances, FEE_BPS_DENOMINATOR, MINIMUM_LIQUIDITY,
};
pub use amm_core::{compute_liquidity_token_pda_seed, compute_vault_pda_seed, PoolDefinition}; pub use amm_core::{compute_liquidity_token_pda_seed, compute_vault_pda_seed, PoolDefinition};
use nssa_core::{ use nssa_core::{
account::{AccountId, AccountWithMetadata, Data}, account::{AccountId, AccountWithMetadata, Data},
@ -28,31 +30,13 @@ fn validate_swap_setup(
"Vault B was not provided" "Vault B was not provided"
); );
let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data) let (vault_a_balance, vault_b_balance) =
.expect("AMM Program expects a valid Token Holding Account for Vault A"); read_vault_fungible_balances("Validate swap setup", vault_a, vault_b);
let token_core::TokenHolding::Fungible {
definition_id: _,
balance: vault_a_balance,
} = vault_a_token_holding
else {
panic!("AMM Program expects a valid Fungible Token Holding Account for Vault A");
};
assert!( assert!(
vault_a_balance >= pool_def_data.reserve_a, vault_a_balance >= pool_def_data.reserve_a,
"Reserve for Token A exceeds vault balance" "Reserve for Token A exceeds vault balance"
); );
let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data)
.expect("AMM Program expects a valid Token Holding Account for Vault B");
let token_core::TokenHolding::Fungible {
definition_id: _,
balance: vault_b_balance,
} = vault_b_token_holding
else {
panic!("AMM Program expects a valid Fungible Token Holding Account for Vault B");
};
assert!( assert!(
vault_b_balance >= pool_def_data.reserve_b, vault_b_balance >= pool_def_data.reserve_b,
"Reserve for Token B exceeds vault balance" "Reserve for Token B exceeds vault balance"
@ -130,6 +114,7 @@ pub fn swap_exact_input(
user_holding_b.clone(), user_holding_b.clone(),
swap_amount_in, swap_amount_in,
min_amount_out, min_amount_out,
pool_def_data.fees,
pool_def_data.reserve_a, pool_def_data.reserve_a,
pool_def_data.reserve_b, pool_def_data.reserve_b,
pool.account_id, pool.account_id,
@ -144,6 +129,7 @@ pub fn swap_exact_input(
user_holding_a.clone(), user_holding_a.clone(),
swap_amount_in, swap_amount_in,
min_amount_out, min_amount_out,
pool_def_data.fees,
pool_def_data.reserve_b, pool_def_data.reserve_b,
pool_def_data.reserve_a, pool_def_data.reserve_a,
pool.account_id, pool.account_id,
@ -178,19 +164,29 @@ fn swap_logic(
user_withdraw: AccountWithMetadata, user_withdraw: AccountWithMetadata,
swap_amount_in: u128, swap_amount_in: u128,
min_amount_out: u128, min_amount_out: u128,
fee_bps: u128,
reserve_deposit_vault_amount: u128, reserve_deposit_vault_amount: u128,
reserve_withdraw_vault_amount: u128, reserve_withdraw_vault_amount: u128,
pool_id: AccountId, pool_id: AccountId,
) -> (Vec<ChainedCall>, u128, u128) { ) -> (Vec<ChainedCall>, u128, u128) {
// Compute withdraw amount let effective_amount_in = swap_amount_in
// Maintains pool constant product .checked_mul(FEE_BPS_DENOMINATOR - fee_bps)
// k = pool_def_data.reserve_a * pool_def_data.reserve_b; .expect("swap_amount_in * (FEE_BPS_DENOMINATOR - fee_bps) overflows u128")
/ FEE_BPS_DENOMINATOR;
assert!(
effective_amount_in != 0,
"Effective swap amount should be nonzero"
);
// Compute the withdraw amount using the fee-adjusted input for pricing.
// The recorded pool reserves are updated later with the full
// `swap_amount_in`, so LP fees accrue inside `reserve_*` via invariant
// growth rather than as a separate vault balance surplus over `reserve_*`.
let withdraw_amount = reserve_withdraw_vault_amount let withdraw_amount = reserve_withdraw_vault_amount
.checked_mul(swap_amount_in) .checked_mul(effective_amount_in)
.expect("reserve * amount_in overflows u128") .expect("reserve * effective_amount_in overflows u128")
/ reserve_deposit_vault_amount / reserve_deposit_vault_amount
.checked_add(swap_amount_in) .checked_add(effective_amount_in)
.expect("reserve + swap_amount_in overflows u128"); .expect("reserve + effective_amount_in overflows u128");
// Slippage check // Slippage check
assert!( assert!(
@ -259,6 +255,7 @@ pub fn swap_exact_output(
max_amount_in, max_amount_in,
pool_def_data.reserve_a, pool_def_data.reserve_a,
pool_def_data.reserve_b, pool_def_data.reserve_b,
pool_def_data.fees,
pool.account_id, pool.account_id,
); );
@ -273,6 +270,7 @@ pub fn swap_exact_output(
max_amount_in, max_amount_in,
pool_def_data.reserve_b, pool_def_data.reserve_b,
pool_def_data.reserve_a, pool_def_data.reserve_a,
pool_def_data.fees,
pool.account_id, pool.account_id,
); );
@ -307,6 +305,7 @@ fn exact_output_swap_logic(
max_amount_in: u128, max_amount_in: u128,
reserve_deposit_vault_amount: u128, reserve_deposit_vault_amount: u128,
reserve_withdraw_vault_amount: u128, reserve_withdraw_vault_amount: u128,
fee_bps: u128,
pool_id: AccountId, pool_id: AccountId,
) -> (Vec<ChainedCall>, u128, u128) { ) -> (Vec<ChainedCall>, u128, u128) {
// Guard: exact_amount_out must be nonzero // Guard: exact_amount_out must be nonzero
@ -318,12 +317,28 @@ fn exact_output_swap_logic(
"Exact amount out exceeds reserve" "Exact amount out exceeds reserve"
); );
// Compute deposit amount using ceiling division // Compute the minimum effective input required to achieve exact_amount_out
// Formula: amount_in = ceil(reserve_in * exact_amount_out / (reserve_out - exact_amount_out)) // using the same floor-rounded fee application as swap_exact_input.
let deposit_amount = reserve_deposit_vault_amount //
// Solve constant product for effective_in (fee already removed):
// effective_in >= ceil(reserve_in * amount_out / (reserve_out - amount_out))
let effective_in_numerator = reserve_deposit_vault_amount
.checked_mul(exact_amount_out) .checked_mul(exact_amount_out)
.expect("reserve * amount_out overflows u128") .expect("reserve * amount_out overflows u128");
.div_ceil(reserve_withdraw_vault_amount - exact_amount_out); let effective_in_denominator = reserve_withdraw_vault_amount
.checked_sub(exact_amount_out)
.expect("reserve_out - amount_out underflows");
let effective_in_min = effective_in_numerator.div_ceil(effective_in_denominator);
// Lift back to gross input so that
// floor(gross_in * (FEE_DENOM - fee) / FEE_DENOM) >= effective_in_min
let fee_multiplier = FEE_BPS_DENOMINATOR
.checked_sub(fee_bps)
.expect("fee_bps exceeds fee denominator");
let deposit_amount = effective_in_min
.checked_mul(FEE_BPS_DENOMINATOR)
.expect("effective_in * FEE_DENOM overflows u128")
.div_ceil(fee_multiplier);
// Slippage check // Slippage check
assert!( assert!(

View File

@ -4,8 +4,9 @@ use std::num::NonZero;
use amm_core::{ use amm_core::{
compute_liquidity_token_pda, compute_liquidity_token_pda_seed, compute_lp_lock_holding_pda, compute_liquidity_token_pda, compute_liquidity_token_pda_seed, compute_lp_lock_holding_pda,
compute_pool_pda, compute_vault_pda, compute_vault_pda_seed, PoolDefinition, FEE_TIER_BPS_1, compute_pool_pda, compute_vault_pda, compute_vault_pda_seed, PoolDefinition,
FEE_TIER_BPS_100, FEE_TIER_BPS_30, FEE_TIER_BPS_5, MINIMUM_LIQUIDITY, FEE_BPS_DENOMINATOR, FEE_TIER_BPS_1, FEE_TIER_BPS_100, FEE_TIER_BPS_30, FEE_TIER_BPS_5,
MINIMUM_LIQUIDITY,
}; };
use nssa_core::{ use nssa_core::{
account::{Account, AccountId, AccountWithMetadata, Data, Nonce}, account::{Account, AccountId, AccountWithMetadata, Data, Nonce},
@ -103,6 +104,16 @@ impl BalanceForTests {
200 200
} }
fn effective_swap_in_a() -> u128 {
BalanceForTests::add_max_amount_a() * (FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier())
/ FEE_BPS_DENOMINATOR
}
fn effective_swap_in_b() -> u128 {
BalanceForTests::add_max_amount_b() * (FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier())
/ FEE_BPS_DENOMINATOR
}
fn add_max_amount_a_low() -> u128 { fn add_max_amount_a_low() -> u128 {
10 10
} }
@ -178,13 +189,13 @@ impl BalanceForTests {
} }
fn swap_amount_out_b() -> u128 { fn swap_amount_out_b() -> u128 {
(BalanceForTests::vault_b_reserve_init() * BalanceForTests::add_max_amount_a()) (BalanceForTests::vault_b_reserve_init() * BalanceForTests::effective_swap_in_a())
/ (BalanceForTests::vault_a_reserve_init() + BalanceForTests::add_max_amount_a()) / (BalanceForTests::vault_a_reserve_init() + BalanceForTests::effective_swap_in_a())
} }
fn swap_amount_out_a() -> u128 { fn swap_amount_out_a() -> u128 {
(BalanceForTests::vault_a_reserve_init() * BalanceForTests::add_max_amount_b()) (BalanceForTests::vault_a_reserve_init() * BalanceForTests::effective_swap_in_b())
/ (BalanceForTests::vault_b_reserve_init() + BalanceForTests::add_max_amount_b()) / (BalanceForTests::vault_b_reserve_init() + BalanceForTests::effective_swap_in_b())
} }
fn add_delta_lp_successful() -> u128 { fn add_delta_lp_successful() -> u128 {
@ -276,7 +287,10 @@ impl ChainedCallForTests {
} }
fn cc_swap_exact_output_token_a_test_1() -> ChainedCall { fn cc_swap_exact_output_token_a_test_1() -> ChainedCall {
let swap_amount: u128 = 498; // reserve_in=1000, amount_out=166, fee=30bps
// required_effective_in = ceil(1000 * 166 / 334) = 498
// deposit = ceil(498 * 10000 / 9970) = 500
let swap_amount: u128 = 500;
ChainedCall::new( ChainedCall::new(
TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID,
@ -329,7 +343,10 @@ impl ChainedCallForTests {
} }
fn cc_swap_exact_output_token_b_test_2() -> ChainedCall { fn cc_swap_exact_output_token_b_test_2() -> ChainedCall {
let swap_amount: u128 = 200; // reserve_in=500, amount_out=285, fee=30bps
// required_effective_in = ceil(500 * 285 / 715) = 200
// deposit = ceil(200 * 10000 / 9970) = 201
let swap_amount: u128 = 201;
ChainedCall::new( ChainedCall::new(
TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID,
@ -343,6 +360,36 @@ impl ChainedCallForTests {
) )
} }
fn cc_swap_rounding_boundary_token_a_in() -> ChainedCall {
ChainedCall::new(
TOKEN_PROGRAM_ID,
vec![
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::vault_a_init(),
],
&token_core::Instruction::Transfer {
amount_to_transfer: 3,
},
)
}
fn cc_swap_rounding_boundary_token_b_out() -> ChainedCall {
let mut vault_b_auth = AccountWithMetadataForTests::vault_b_init();
vault_b_auth.is_authorized = true;
ChainedCall::new(
TOKEN_PROGRAM_ID,
vec![vault_b_auth, AccountWithMetadataForTests::user_holding_b()],
&token_core::Instruction::Transfer {
amount_to_transfer: 1,
},
)
.with_pda_seeds(vec![compute_vault_pda_seed(
IdForTests::pool_definition_id(),
IdForTests::token_b_definition_id(),
)])
}
fn cc_add_token_a() -> ChainedCall { fn cc_add_token_a() -> ChainedCall {
ChainedCall::new( ChainedCall::new(
TOKEN_PROGRAM_ID, TOKEN_PROGRAM_ID,
@ -885,6 +932,29 @@ impl AccountWithMetadataForTests {
} }
} }
fn pool_definition_swap_rounding_boundary_init() -> AccountWithMetadata {
AccountWithMetadata {
account: Account {
program_owner: ProgramId::default(),
balance: 0u128,
data: Data::from(&PoolDefinition {
definition_token_a_id: IdForTests::token_a_definition_id(),
definition_token_b_id: IdForTests::token_b_definition_id(),
vault_a_id: IdForTests::vault_a_id(),
vault_b_id: IdForTests::vault_b_id(),
liquidity_pool_id: IdForTests::token_lp_definition_id(),
liquidity_pool_supply: MINIMUM_LIQUIDITY,
reserve_a: 1_000,
reserve_b: 1_000,
fees: FEE_TIER_BPS_30,
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::pool_definition_id(),
}
}
fn pool_definition_init_reserve_a_zero() -> AccountWithMetadata { fn pool_definition_init_reserve_a_zero() -> AccountWithMetadata {
AccountWithMetadata { AccountWithMetadata {
account: Account { account: Account {
@ -1024,6 +1094,9 @@ impl AccountWithMetadataForTests {
} }
fn pool_definition_swap_exact_output_test_1() -> AccountWithMetadata { fn pool_definition_swap_exact_output_test_1() -> AccountWithMetadata {
// swap token_a in for 166 token_b out, fee=30bps
// reserve_a: 1000 + 500 = 1500 (gross deposit, see
// cc_swap_exact_output_token_a_test_1) reserve_b: 500 - 166 = 334
AccountWithMetadata { AccountWithMetadata {
account: Account { account: Account {
program_owner: ProgramId::default(), program_owner: ProgramId::default(),
@ -1035,7 +1108,7 @@ impl AccountWithMetadataForTests {
vault_b_id: IdForTests::vault_b_id(), vault_b_id: IdForTests::vault_b_id(),
liquidity_pool_id: IdForTests::token_lp_definition_id(), liquidity_pool_id: IdForTests::token_lp_definition_id(),
liquidity_pool_supply: BalanceForTests::lp_supply_init(), liquidity_pool_supply: BalanceForTests::lp_supply_init(),
reserve_a: 1498_u128, reserve_a: 1500_u128,
reserve_b: 334_u128, reserve_b: 334_u128,
fees: BalanceForTests::fee_tier(), fees: BalanceForTests::fee_tier(),
}), }),
@ -1059,7 +1132,7 @@ impl AccountWithMetadataForTests {
liquidity_pool_id: IdForTests::token_lp_definition_id(), liquidity_pool_id: IdForTests::token_lp_definition_id(),
liquidity_pool_supply: BalanceForTests::lp_supply_init(), liquidity_pool_supply: BalanceForTests::lp_supply_init(),
reserve_a: 715_u128, reserve_a: 715_u128,
reserve_b: 700_u128, reserve_b: 701_u128,
fees: BalanceForTests::fee_tier(), fees: BalanceForTests::fee_tier(),
}), }),
nonce: Nonce(0), nonce: Nonce(0),
@ -1069,6 +1142,29 @@ impl AccountWithMetadataForTests {
} }
} }
fn pool_definition_swap_rounding_boundary_post() -> AccountWithMetadata {
AccountWithMetadata {
account: Account {
program_owner: ProgramId::default(),
balance: 0_u128,
data: Data::from(&PoolDefinition {
definition_token_a_id: IdForTests::token_a_definition_id(),
definition_token_b_id: IdForTests::token_b_definition_id(),
vault_a_id: IdForTests::vault_a_id(),
vault_b_id: IdForTests::vault_b_id(),
liquidity_pool_id: IdForTests::token_lp_definition_id(),
liquidity_pool_supply: MINIMUM_LIQUIDITY,
reserve_a: 1003_u128,
reserve_b: 999_u128,
fees: FEE_TIER_BPS_30,
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::pool_definition_id(),
}
}
fn pool_definition_add_zero_lp() -> AccountWithMetadata { fn pool_definition_add_zero_lp() -> AccountWithMetadata {
AccountWithMetadata { AccountWithMetadata {
account: Account { account: Account {
@ -1115,6 +1211,29 @@ impl AccountWithMetadataForTests {
} }
} }
fn pool_definition_init_low_balances() -> AccountWithMetadata {
AccountWithMetadata {
account: Account {
program_owner: ProgramId::default(),
balance: 0u128,
data: Data::from(&PoolDefinition {
definition_token_a_id: IdForTests::token_a_definition_id(),
definition_token_b_id: IdForTests::token_b_definition_id(),
vault_a_id: IdForTests::vault_a_id(),
vault_b_id: IdForTests::vault_b_id(),
liquidity_pool_id: IdForTests::token_lp_definition_id(),
liquidity_pool_supply: MINIMUM_LIQUIDITY,
reserve_a: BalanceForTests::vault_a_reserve_low(),
reserve_b: BalanceForTests::vault_b_reserve_low(),
fees: BalanceForTests::fee_tier(),
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::pool_definition_id(),
}
}
fn pool_definition_remove_successful() -> AccountWithMetadata { fn pool_definition_remove_successful() -> AccountWithMetadata {
AccountWithMetadata { AccountWithMetadata {
account: Account { account: Account {
@ -1342,6 +1461,40 @@ fn test_call_add_liquidity_zero_balance_2() {
); );
} }
#[should_panic(expected = "Vaults' balances must be at least the reserve amounts")]
#[test]
fn test_call_add_liquidity_vault_a_balance_below_reserve() {
let _post_states = add_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init_low(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(),
BalanceForTests::add_max_amount_a(),
BalanceForTests::add_max_amount_b(),
);
}
#[should_panic(expected = "Vaults' balances must be at least the reserve amounts")]
#[test]
fn test_call_add_liquidity_vault_b_balance_below_reserve() {
let _post_states = add_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init_low(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(),
BalanceForTests::add_max_amount_a(),
BalanceForTests::add_max_amount_b(),
);
}
#[should_panic(expected = "Vaults' balances must be at least the reserve amounts")] #[should_panic(expected = "Vaults' balances must be at least the reserve amounts")]
#[test] #[test]
fn test_call_add_liquidity_vault_insufficient_balance_1() { fn test_call_add_liquidity_vault_insufficient_balance_1() {
@ -1353,9 +1506,9 @@ fn test_call_add_liquidity_vault_insufficient_balance_1() {
AccountWithMetadataForTests::user_holding_a(), AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(), AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(), AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::add_max_amount_a()).unwrap(), NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(),
BalanceForTests::add_max_amount_a(),
BalanceForTests::add_max_amount_b(), BalanceForTests::add_max_amount_b(),
BalanceForTests::add_min_amount_lp(),
); );
} }
@ -1370,9 +1523,9 @@ fn test_call_add_liquidity_vault_insufficient_balance_2() {
AccountWithMetadataForTests::user_holding_a(), AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(), AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(), AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::add_max_amount_a()).unwrap(), NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(),
BalanceForTests::add_max_amount_a(),
BalanceForTests::add_max_amount_b(), BalanceForTests::add_max_amount_b(),
BalanceForTests::add_min_amount_lp(),
); );
} }
@ -2052,6 +2205,84 @@ fn test_call_swap_below_min_out() {
); );
} }
#[should_panic(expected = "Effective swap amount should be nonzero")]
#[test]
fn test_call_swap_effective_amount_zero() {
let _post_states = swap_exact_input(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
1,
0,
IdForTests::token_a_definition_id(),
);
}
#[should_panic(expected = "Withdraw amount should be nonzero")]
#[test]
fn test_call_swap_output_rounds_to_zero() {
let _post_states = swap_exact_input(
AccountWithMetadataForTests::pool_definition_init_low_balances(),
AccountWithMetadataForTests::vault_a_init_low(),
AccountWithMetadataForTests::vault_b_init_low(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
2,
0,
IdForTests::token_a_definition_id(),
);
}
#[should_panic(expected = "Withdraw amount is less than minimal amount out")]
#[test]
fn test_call_swap_exact_input_rejects_amount_that_rounds_down_below_target_output() {
let _post_states = swap_exact_input(
AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
2,
1,
IdForTests::token_a_definition_id(),
);
}
#[test]
fn test_call_swap_exact_input_accepts_smallest_amount_for_rounded_boundary() {
let (post_states, chained_calls) = swap_exact_input(
AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
3,
1,
IdForTests::token_a_definition_id(),
);
let pool_post = post_states[0].clone();
assert_eq!(
AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_post().account,
*pool_post.account()
);
let chained_call_a = chained_calls[0].clone();
let chained_call_b = chained_calls[1].clone();
assert_eq!(
chained_call_a,
ChainedCallForTests::cc_swap_rounding_boundary_token_a_in()
);
assert_eq!(
chained_call_b,
ChainedCallForTests::cc_swap_rounding_boundary_token_b_out()
);
}
#[test] #[test]
fn test_call_swap_chained_call_successful_1() { fn test_call_swap_chained_call_successful_1() {
let (post_states, chained_calls) = swap_exact_input( let (post_states, chained_calls) = swap_exact_input(
@ -2317,6 +2548,74 @@ fn call_swap_exact_output_chained_call_successful_2() {
); );
} }
// The minimum effective input for exact_amount_out=166 on the 1000/500 pool is 498.
// After fee rounding, the true minimum gross input is 500, so 499 must be rejected.
#[should_panic(expected = "Required input exceeds maximum amount in")]
#[test]
fn call_swap_exact_output_fee_enforced() {
let _post_states = swap_exact_output(
AccountWithMetadataForTests::pool_definition_swap_exact_output_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
166_u128, // exact_amount_out: token_b
499_u128, // max_amount_in: still one short after fee rounding
IdForTests::token_a_definition_id(),
);
}
// On a 1000/1000 pool at 0.3%, exact_amount_out = 1 requires gross input 3.
// max_amount_in = 2 must be rejected because the exact-input path would round
// 2 down to effective_in = 1 and still produce 0 output.
#[should_panic(expected = "Required input exceeds maximum amount in")]
#[test]
fn call_swap_exact_output_rejects_max_in_that_rounds_down_below_target_output() {
let _post_states = swap_exact_output(
AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
1,
2,
IdForTests::token_a_definition_id(),
);
}
#[test]
fn call_swap_exact_output_accepts_smallest_max_in_for_rounded_boundary() {
let (post_states, chained_calls) = swap_exact_output(
AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
1,
3,
IdForTests::token_a_definition_id(),
);
let pool_post = post_states[0].clone();
assert_eq!(
AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_post().account,
*pool_post.account()
);
let chained_call_a = chained_calls[0].clone();
let chained_call_b = chained_calls[1].clone();
assert_eq!(
chained_call_a,
ChainedCallForTests::cc_swap_rounding_boundary_token_a_in()
);
assert_eq!(
chained_call_b,
ChainedCallForTests::cc_swap_rounding_boundary_token_b_out()
);
}
// Without the fix, `reserve_a * exact_amount_out` silently wraps to 0 in release mode, // Without the fix, `reserve_a * exact_amount_out` silently wraps to 0 in release mode,
// making `deposit_amount = 0`. The slippage check `0 <= max_amount_in` always passes, // making `deposit_amount = 0`. The slippage check `0 <= max_amount_in` always passes,
// so an attacker receives `exact_amount_out` tokens while paying nothing. // so an attacker receives `exact_amount_out` tokens while paying nothing.
@ -2869,7 +3168,7 @@ fn remove_liquidity_overflow_protection() {
); );
} }
#[should_panic(expected = "reserve * amount_in overflows u128")] #[should_panic(expected = "reserve * effective_amount_in overflows u128")]
#[test] #[test]
fn swap_exact_input_overflow_protection() { fn swap_exact_input_overflow_protection() {
let large_reserve: u128 = u128::MAX / 2 + 1; let large_reserve: u128 = u128::MAX / 2 + 1;
@ -2924,7 +3223,8 @@ fn swap_exact_input_overflow_protection() {
account_id: IdForTests::vault_b_id(), account_id: IdForTests::vault_b_id(),
}; };
// Swap token_a in: withdraw_amount = reserve_b * swap_amount_in / (reserve_a + swap_amount_in) // Swap token_a in: withdraw_amount = reserve_b * effective_amount_in / (reserve_a +
// effective_amount_in) With fee_bps=30: effective_amount_in = 3 * 9970 / 10000 = 2
// reserve_b is large, so reserve_b * 2 overflows // reserve_b is large, so reserve_b * 2 overflows
let _result = swap_exact_input( let _result = swap_exact_input(
pool, pool,
@ -2932,7 +3232,7 @@ fn swap_exact_input_overflow_protection() {
vault_b, vault_b,
AccountWithMetadataForTests::user_holding_a(), AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(), AccountWithMetadataForTests::user_holding_b(),
2, 3,
1, 1,
IdForTests::token_a_definition_id(), IdForTests::token_a_definition_id(),
); );

View File

@ -164,8 +164,16 @@ impl Balances {
200 200
} }
fn reserve_a_swap_1() -> u128 {
3_575
}
fn reserve_b_swap_1() -> u128 {
3_500
}
fn vault_a_swap_1() -> u128 { fn vault_a_swap_1() -> u128 {
3_572 3_575
} }
fn vault_b_swap_1() -> u128 { fn vault_b_swap_1() -> u128 {
@ -173,19 +181,27 @@ impl Balances {
} }
fn user_a_swap_1() -> u128 { fn user_a_swap_1() -> u128 {
11_428 11_425
} }
fn user_b_swap_1() -> u128 { fn user_b_swap_1() -> u128 {
9_000 9_000
} }
fn reserve_a_swap_2() -> u128 {
6_000
}
fn reserve_b_swap_2() -> u128 {
2_085
}
fn vault_a_swap_2() -> u128 { fn vault_a_swap_2() -> u128 {
6_000 6_000
} }
fn vault_b_swap_2() -> u128 { fn vault_b_swap_2() -> u128 {
2_084 2_085
} }
fn user_a_swap_2() -> u128 { fn user_a_swap_2() -> u128 {
@ -193,7 +209,7 @@ impl Balances {
} }
fn user_b_swap_2() -> u128 { fn user_b_swap_2() -> u128 {
10_416 10_415
} }
fn vault_a_add() -> u128 { fn vault_a_add() -> u128 {
@ -405,8 +421,8 @@ impl Accounts {
vault_b_id: Ids::vault_b(), vault_b_id: Ids::vault_b(),
liquidity_pool_id: Ids::token_lp_definition(), liquidity_pool_id: Ids::token_lp_definition(),
liquidity_pool_supply: Balances::pool_lp_supply_init(), liquidity_pool_supply: Balances::pool_lp_supply_init(),
reserve_a: Balances::vault_a_swap_1(), reserve_a: Balances::reserve_a_swap_1(),
reserve_b: Balances::vault_b_swap_1(), reserve_b: Balances::reserve_b_swap_1(),
fees: Balances::fee_tier(), fees: Balances::fee_tier(),
}), }),
nonce: Nonce(0), nonce: Nonce(0),
@ -472,8 +488,8 @@ impl Accounts {
vault_b_id: Ids::vault_b(), vault_b_id: Ids::vault_b(),
liquidity_pool_id: Ids::token_lp_definition(), liquidity_pool_id: Ids::token_lp_definition(),
liquidity_pool_supply: Balances::pool_lp_supply_init(), liquidity_pool_supply: Balances::pool_lp_supply_init(),
reserve_a: Balances::vault_a_swap_2(), reserve_a: Balances::reserve_a_swap_2(),
reserve_b: Balances::vault_b_swap_2(), reserve_b: Balances::reserve_b_swap_2(),
fees: Balances::fee_tier(), fees: Balances::fee_tier(),
}), }),
nonce: Nonce(0), nonce: Nonce(0),
@ -918,6 +934,10 @@ fn state_for_amm_tests_with_new_def() -> V03State {
state state
} }
fn current_nonce(state: &V03State, account_id: AccountId) -> Nonce {
state.get_account_by_id(account_id).nonce
}
fn try_execute_new_definition(state: &mut V03State, fees: u128) -> Result<(), NssaError> { fn try_execute_new_definition(state: &mut V03State, fees: u128) -> Result<(), NssaError> {
let instruction = amm_core::Instruction::NewDefinition { let instruction = amm_core::Instruction::NewDefinition {
token_a_amount: Balances::vault_a_init(), token_a_amount: Balances::vault_a_init(),
@ -938,7 +958,10 @@ fn try_execute_new_definition(state: &mut V03State, fees: u128) -> Result<(), Ns
Ids::user_b(), Ids::user_b(),
Ids::user_lp(), Ids::user_lp(),
], ],
vec![Nonce(0), Nonce(0)], vec![
current_nonce(state, Ids::user_a()),
current_nonce(state, Ids::user_b()),
],
instruction, instruction,
) )
.unwrap(); .unwrap();
@ -954,6 +977,163 @@ fn execute_new_definition(state: &mut V03State, fees: u128) {
try_execute_new_definition(state, fees).unwrap(); try_execute_new_definition(state, fees).unwrap();
} }
fn execute_swap_a_to_b(state: &mut V03State, swap_amount_in: u128, min_amount_out: u128) {
let instruction = amm_core::Instruction::SwapExactInput {
swap_amount_in,
min_amount_out,
token_definition_id_in: Ids::token_a_definition(),
};
let message = public_transaction::Message::try_new(
Ids::amm_program(),
vec![
Ids::pool_definition(),
Ids::vault_a(),
Ids::vault_b(),
Ids::user_a(),
Ids::user_b(),
],
vec![current_nonce(state, Ids::user_a())],
instruction,
)
.unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_a()]);
let tx = PublicTransaction::new(message, witness_set);
state.transition_from_public_transaction(&tx, 0).unwrap();
}
fn execute_swap_b_to_a(state: &mut V03State, swap_amount_in: u128, min_amount_out: u128) {
let instruction = amm_core::Instruction::SwapExactInput {
swap_amount_in,
min_amount_out,
token_definition_id_in: Ids::token_b_definition(),
};
let message = public_transaction::Message::try_new(
Ids::amm_program(),
vec![
Ids::pool_definition(),
Ids::vault_a(),
Ids::vault_b(),
Ids::user_a(),
Ids::user_b(),
],
vec![current_nonce(state, Ids::user_b())],
instruction,
)
.unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_b()]);
let tx = PublicTransaction::new(message, witness_set);
state.transition_from_public_transaction(&tx, 0).unwrap();
}
fn execute_add_liquidity(
state: &mut V03State,
min_amount_liquidity: u128,
max_amount_to_add_token_a: u128,
max_amount_to_add_token_b: u128,
) {
let instruction = amm_core::Instruction::AddLiquidity {
min_amount_liquidity,
max_amount_to_add_token_a,
max_amount_to_add_token_b,
};
let message = public_transaction::Message::try_new(
Ids::amm_program(),
vec![
Ids::pool_definition(),
Ids::vault_a(),
Ids::vault_b(),
Ids::token_lp_definition(),
Ids::user_a(),
Ids::user_b(),
Ids::user_lp(),
],
vec![
current_nonce(state, Ids::user_a()),
current_nonce(state, Ids::user_b()),
],
instruction,
)
.unwrap();
let witness_set =
public_transaction::WitnessSet::for_message(&message, &[&Keys::user_a(), &Keys::user_b()]);
let tx = PublicTransaction::new(message, witness_set);
state.transition_from_public_transaction(&tx, 0).unwrap();
}
fn execute_remove_liquidity(
state: &mut V03State,
remove_liquidity_amount: u128,
min_amount_to_remove_token_a: u128,
min_amount_to_remove_token_b: u128,
) {
let instruction = amm_core::Instruction::RemoveLiquidity {
remove_liquidity_amount,
min_amount_to_remove_token_a,
min_amount_to_remove_token_b,
};
let message = public_transaction::Message::try_new(
Ids::amm_program(),
vec![
Ids::pool_definition(),
Ids::vault_a(),
Ids::vault_b(),
Ids::token_lp_definition(),
Ids::user_a(),
Ids::user_b(),
Ids::user_lp(),
],
vec![current_nonce(state, Ids::user_lp())],
instruction,
)
.unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_lp()]);
let tx = PublicTransaction::new(message, witness_set);
state.transition_from_public_transaction(&tx, 0).unwrap();
}
fn fungible_balance(account: &Account) -> u128 {
let holding = TokenHolding::try_from(&account.data).expect("expected token holding");
let TokenHolding::Fungible {
definition_id: _,
balance,
} = holding
else {
panic!("expected fungible token holding")
};
balance
}
fn pool_definition(account: &Account) -> PoolDefinition {
PoolDefinition::try_from(&account.data).expect("expected pool definition")
}
fn fungible_total_supply(account: &Account) -> u128 {
let definition = TokenDefinition::try_from(&account.data).expect("expected token definition");
let TokenDefinition::Fungible {
name: _,
total_supply,
metadata_id: _,
} = definition
else {
panic!("expected fungible token definition")
};
total_supply
}
#[test] #[test]
fn amm_remove_liquidity() { fn amm_remove_liquidity() {
let mut state = state_for_amm_tests(); let mut state = state_for_amm_tests();
@ -1322,3 +1502,102 @@ fn amm_swap_a_to_b() {
Accounts::user_b_holding_swap_2() Accounts::user_b_holding_swap_2()
); );
} }
#[test]
fn amm_fee_accumulates_across_multiple_swaps_and_pays_out_on_remove() {
let mut state = state_for_amm_tests();
execute_swap_a_to_b(&mut state, 1_000, 200);
execute_swap_b_to_a(&mut state, 1_000, 200);
let pool_before_remove = pool_definition(&state.get_account_by_id(Ids::pool_definition()));
assert_eq!(pool_before_remove.reserve_a, 4_060);
assert_eq!(pool_before_remove.reserve_b, 3_085);
assert_eq!(pool_before_remove.fees, Balances::fee_tier());
let vault_a_before_remove = fungible_balance(&state.get_account_by_id(Ids::vault_a()));
let vault_b_before_remove = fungible_balance(&state.get_account_by_id(Ids::vault_b()));
assert_eq!(vault_a_before_remove, 4_060);
assert_eq!(vault_b_before_remove, 3_085);
assert_eq!(vault_a_before_remove, pool_before_remove.reserve_a);
assert_eq!(vault_b_before_remove, pool_before_remove.reserve_b);
execute_remove_liquidity(&mut state, 1_000, 812, 617);
let pool_after_remove = pool_definition(&state.get_account_by_id(Ids::pool_definition()));
assert_eq!(pool_after_remove.reserve_a, 3_248);
assert_eq!(pool_after_remove.reserve_b, 2_468);
assert_eq!(pool_after_remove.liquidity_pool_supply, 4_000);
let vault_a_after_remove = fungible_balance(&state.get_account_by_id(Ids::vault_a()));
let vault_b_after_remove = fungible_balance(&state.get_account_by_id(Ids::vault_b()));
assert_eq!(vault_a_after_remove, 3_248);
assert_eq!(vault_b_after_remove, 2_468);
assert_eq!(vault_a_after_remove, pool_after_remove.reserve_a);
assert_eq!(vault_b_after_remove, pool_after_remove.reserve_b);
assert_eq!(
fungible_balance(&state.get_account_by_id(Ids::user_a())),
11_752
);
assert_eq!(
fungible_balance(&state.get_account_by_id(Ids::user_b())),
10_032
);
assert_eq!(
fungible_balance(&state.get_account_by_id(Ids::user_lp())),
1_000
);
assert_eq!(
fungible_total_supply(&state.get_account_by_id(Ids::token_lp_definition())),
4_000
);
}
#[test]
fn amm_add_liquidity_after_fee_accrual() {
let mut state = state_for_amm_tests();
execute_swap_a_to_b(&mut state, 1_000, 200);
execute_swap_b_to_a(&mut state, 1_000, 200);
execute_swap_a_to_b(&mut state, 1_000, 200);
execute_swap_b_to_a(&mut state, 1_000, 200);
let pool_before_add = pool_definition(&state.get_account_by_id(Ids::pool_definition()));
let vault_a_before_add = fungible_balance(&state.get_account_by_id(Ids::vault_a()));
let vault_b_before_add = fungible_balance(&state.get_account_by_id(Ids::vault_b()));
assert_eq!(pool_before_add.reserve_a, 3_608);
assert_eq!(pool_before_add.reserve_b, 3_477);
assert_eq!(vault_a_before_add, pool_before_add.reserve_a);
assert_eq!(vault_b_before_add, pool_before_add.reserve_b);
execute_add_liquidity(&mut state, 1_436, 2_000, 1_000);
let pool_after_add = pool_definition(&state.get_account_by_id(Ids::pool_definition()));
let vault_a_after_add = fungible_balance(&state.get_account_by_id(Ids::vault_a()));
let vault_b_after_add = fungible_balance(&state.get_account_by_id(Ids::vault_b()));
assert_eq!(pool_after_add.reserve_a, 4_645);
assert_eq!(pool_after_add.reserve_b, 4_477);
assert_eq!(pool_after_add.liquidity_pool_supply, 6_437);
assert_eq!(vault_a_after_add, pool_after_add.reserve_a);
assert_eq!(vault_b_after_add, pool_after_add.reserve_b);
assert_eq!(
fungible_balance(&state.get_account_by_id(Ids::user_a())),
10_355
);
assert_eq!(
fungible_balance(&state.get_account_by_id(Ids::user_b())),
8_023
);
assert_eq!(
fungible_balance(&state.get_account_by_id(Ids::user_lp())),
3_437
);
assert_eq!(
fungible_total_supply(&state.get_account_by_id(Ids::token_lp_definition())),
6_437
);
}