From 46643941ac92cc2f91a3b09d16d90bf5bc5737c2 Mon Sep 17 00:00:00 2001 From: Andrea Franz Date: Tue, 7 Apr 2026 10:38:14 +0200 Subject: [PATCH] fix(amm): use checked mul/add/sub to avoid overflows/underflows --- amm/src/add.rs | 41 +++++-- amm/src/new_definition.rs | 6 +- amm/src/remove.rs | 29 +++-- amm/src/swap.rs | 18 ++- amm/src/tests.rs | 245 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 319 insertions(+), 20 deletions(-) diff --git a/amm/src/add.rs b/amm/src/add.rs index 0c4a04c..73ec5a3 100644 --- a/amm/src/add.rs +++ b/amm/src/add.rs @@ -80,10 +80,16 @@ pub fn add_liquidity( ); // Calculate actual_amounts - let ideal_a: u128 = - (pool_def_data.reserve_a * max_amount_to_add_token_b) / pool_def_data.reserve_b; - let ideal_b: u128 = - (pool_def_data.reserve_b * max_amount_to_add_token_a) / pool_def_data.reserve_a; + let ideal_a: u128 = pool_def_data + .reserve_a + .checked_mul(max_amount_to_add_token_b) + .expect("reserve_a * max_amount_b overflows u128") + / pool_def_data.reserve_b; + let ideal_b: u128 = pool_def_data + .reserve_b + .checked_mul(max_amount_to_add_token_a) + .expect("reserve_b * max_amount_a overflows u128") + / pool_def_data.reserve_a; let actual_amount_a = if ideal_a > max_amount_to_add_token_a { max_amount_to_add_token_a @@ -111,8 +117,16 @@ pub fn add_liquidity( // 4. Calculate LP to mint let delta_lp = std::cmp::min( - pool_def_data.liquidity_pool_supply * actual_amount_a / pool_def_data.reserve_a, - pool_def_data.liquidity_pool_supply * actual_amount_b / pool_def_data.reserve_b, + pool_def_data + .liquidity_pool_supply + .checked_mul(actual_amount_a) + .expect("liquidity_pool_supply * actual_amount_a overflows u128") + / pool_def_data.reserve_a, + pool_def_data + .liquidity_pool_supply + .checked_mul(actual_amount_b) + .expect("liquidity_pool_supply * actual_amount_b overflows u128") + / pool_def_data.reserve_b, ); assert!(delta_lp != 0, "Payable LP must be nonzero"); @@ -125,9 +139,18 @@ pub fn add_liquidity( // 5. Update pool account let mut pool_post = pool.account.clone(); let pool_post_definition = PoolDefinition { - liquidity_pool_supply: pool_def_data.liquidity_pool_supply + delta_lp, - reserve_a: pool_def_data.reserve_a + actual_amount_a, - reserve_b: pool_def_data.reserve_b + actual_amount_b, + liquidity_pool_supply: pool_def_data + .liquidity_pool_supply + .checked_add(delta_lp) + .expect("liquidity_pool_supply + delta_lp overflows u128"), + reserve_a: pool_def_data + .reserve_a + .checked_add(actual_amount_a) + .expect("reserve_a + actual_amount_a overflows u128"), + reserve_b: pool_def_data + .reserve_b + .checked_add(actual_amount_b) + .expect("reserve_b + actual_amount_b overflows u128"), ..pool_def_data }; diff --git a/amm/src/new_definition.rs b/amm/src/new_definition.rs index c03ee61..3207dd4 100644 --- a/amm/src/new_definition.rs +++ b/amm/src/new_definition.rs @@ -91,7 +91,11 @@ pub fn new_definition( } // LP Token minting calculation - let initial_lp = (token_a_amount.get() * token_b_amount.get()).isqrt(); + let initial_lp = token_a_amount + .get() + .checked_mul(token_b_amount.get()) + .expect("token_a * token_b overflows u128") + .isqrt(); assert!( initial_lp > MINIMUM_LIQUIDITY, "Initial liquidity must exceed minimum liquidity lock" diff --git a/amm/src/remove.rs b/amm/src/remove.rs index 3c2bf08..dcd4428 100644 --- a/amm/src/remove.rs +++ b/amm/src/remove.rs @@ -94,10 +94,16 @@ pub fn remove_liquidity( "Cannot remove locked minimum liquidity" ); - let withdraw_amount_a = - (pool_def_data.reserve_a * remove_liquidity_amount) / pool_def_data.liquidity_pool_supply; - let withdraw_amount_b = - (pool_def_data.reserve_b * remove_liquidity_amount) / pool_def_data.liquidity_pool_supply; + let withdraw_amount_a = pool_def_data + .reserve_a + .checked_mul(remove_liquidity_amount) + .expect("reserve_a * remove_liquidity_amount overflows u128") + / pool_def_data.liquidity_pool_supply; + let withdraw_amount_b = pool_def_data + .reserve_b + .checked_mul(remove_liquidity_amount) + .expect("reserve_b * remove_liquidity_amount overflows u128") + / pool_def_data.liquidity_pool_supply; // 3. Validate and slippage check assert!( @@ -115,9 +121,18 @@ pub fn remove_liquidity( // 5. Update pool account let mut pool_post = pool.account.clone(); let pool_post_definition = PoolDefinition { - liquidity_pool_supply: pool_def_data.liquidity_pool_supply - delta_lp, - reserve_a: pool_def_data.reserve_a - withdraw_amount_a, - reserve_b: pool_def_data.reserve_b - withdraw_amount_b, + liquidity_pool_supply: pool_def_data + .liquidity_pool_supply + .checked_sub(delta_lp) + .expect("liquidity_pool_supply - delta_lp underflows"), + reserve_a: pool_def_data + .reserve_a + .checked_sub(withdraw_amount_a) + .expect("reserve_a - withdraw_amount_a underflows"), + reserve_b: pool_def_data + .reserve_b + .checked_sub(withdraw_amount_b) + .expect("reserve_b - withdraw_amount_b underflows"), active: true, ..pool_def_data.clone() }; diff --git a/amm/src/swap.rs b/amm/src/swap.rs index 9dae888..54ca804 100644 --- a/amm/src/swap.rs +++ b/amm/src/swap.rs @@ -76,8 +76,18 @@ fn create_swap_post_states( ) -> Vec { let mut pool_post = pool.account; let pool_post_definition = PoolDefinition { - reserve_a: pool_def_data.reserve_a + deposit_a - withdraw_a, - reserve_b: pool_def_data.reserve_b + deposit_b - withdraw_b, + reserve_a: pool_def_data + .reserve_a + .checked_add(deposit_a) + .expect("reserve_a + deposit_a overflows u128") + .checked_sub(withdraw_a) + .expect("reserve_a + deposit_a - withdraw_a underflows"), + reserve_b: pool_def_data + .reserve_b + .checked_add(deposit_b) + .expect("reserve_b + deposit_b overflows u128") + .checked_sub(withdraw_b) + .expect("reserve_b + deposit_b - withdraw_b underflows"), ..pool_def_data }; @@ -173,7 +183,9 @@ fn swap_logic( let withdraw_amount = reserve_withdraw_vault_amount .checked_mul(swap_amount_in) .expect("reserve * amount_in overflows u128") - / (reserve_deposit_vault_amount + swap_amount_in); + / reserve_deposit_vault_amount + .checked_add(swap_amount_in) + .expect("reserve + swap_amount_in overflows u128"); // Slippage check assert!( diff --git a/amm/src/tests.rs b/amm/src/tests.rs index 4d67f70..d997b8c 100644 --- a/amm/src/tests.rs +++ b/amm/src/tests.rs @@ -2774,3 +2774,248 @@ fn test_donation_then_add_liquidity_sync_mitigates_mispricing() { assert!(synced_delta_lp < unsynced_delta_lp); } + +#[should_panic(expected = "token_a * token_b overflows u128")] +#[test] +fn new_definition_overflow_protection() { + let large_amount = u128::MAX / 2 + 1; + + let _result = new_definition( + AccountWithMetadataForTests::pool_definition_reinitializable(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::pool_lp_reinitializable(), + AccountWithMetadataForTests::lp_lock_holding_uninit(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_uninit(), + NonZero::new(large_amount).unwrap(), + NonZero::new(2).unwrap(), + AMM_PROGRAM_ID, + ); +} + +#[should_panic(expected = "reserve_a * max_amount_b overflows u128")] +#[test] +fn add_liquidity_overflow_protection() { + let large_reserve: u128 = u128::MAX / 2 + 1; + let reserve_b: u128 = 1_000; + + let pool = AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: 1_000, + reserve_a: large_reserve, + reserve_b, + fees: 0, + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + }; + + let vault_a = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance: large_reserve, + }), + nonce: Nonce(0), + }, + is_authorized: false, + account_id: IdForTests::vault_a_id(), + }; + + let vault_b = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance: reserve_b, + }), + nonce: Nonce(0), + }, + is_authorized: false, + account_id: IdForTests::vault_b_id(), + }; + + let _result = add_liquidity( + pool, + vault_a, + vault_b, + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(1).unwrap(), + 500, + 2, // max_amount_b=2 → reserve_a * 2 overflows + ); +} + +#[should_panic(expected = "reserve_a * remove_liquidity_amount overflows u128")] +#[test] +fn remove_liquidity_overflow_protection() { + let large_reserve: u128 = u128::MAX / 2 + 1; + let reserve_b: u128 = 1_000; + let lp_supply: u128 = 1_002; // must exceed MINIMUM_LIQUIDITY so remove_amount=2 passes the lock check + + let pool = AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: lp_supply, + reserve_a: large_reserve, + reserve_b, + fees: 0, + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + }; + + let vault_a = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance: large_reserve, + }), + nonce: Nonce(0), + }, + is_authorized: false, + account_id: IdForTests::vault_a_id(), + }; + + let vault_b = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance: reserve_b, + }), + nonce: Nonce(0), + }, + is_authorized: false, + account_id: IdForTests::vault_b_id(), + }; + + let user_lp = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_lp_definition_id(), + balance: 2, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::user_token_lp_id(), + }; + + let _result = remove_liquidity( + pool, + vault_a, + vault_b, + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + user_lp, + NonZero::new(2).unwrap(), // remove_amount=2 → reserve_a * 2 overflows + 1, + 1, + ); +} + +#[should_panic(expected = "reserve * amount_in overflows u128")] +#[test] +fn swap_exact_input_overflow_protection() { + let large_reserve: u128 = u128::MAX / 2 + 1; + let reserve_b: u128 = 1_000; + + let pool = AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: 1, + reserve_a: 1_000, + reserve_b: large_reserve, + fees: 0, + active: true, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + }; + + let vault_a = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance: reserve_b, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::vault_a_id(), + }; + + let vault_b = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance: large_reserve, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::vault_b_id(), + }; + + // Swap token_a in: withdraw_amount = reserve_b * swap_amount_in / (reserve_a + swap_amount_in) + // reserve_b is large, so reserve_b * 2 overflows + let _result = swap_exact_input( + pool, + vault_a, + vault_b, + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 2, + 1, + IdForTests::token_a_definition_id(), + ); +}