fix(amm): use checked mul/add/sub to avoid overflows/underflows

This commit is contained in:
Andrea Franz 2026-04-07 10:38:14 +02:00 committed by r4bbit
parent 1dbc1dc411
commit 46643941ac
No known key found for this signature in database
GPG Key ID: E95F1E9447DC91A9
5 changed files with 319 additions and 20 deletions

View File

@ -80,10 +80,16 @@ pub fn add_liquidity(
); );
// Calculate actual_amounts // Calculate actual_amounts
let ideal_a: u128 = let ideal_a: u128 = pool_def_data
(pool_def_data.reserve_a * max_amount_to_add_token_b) / pool_def_data.reserve_b; .reserve_a
let ideal_b: u128 = .checked_mul(max_amount_to_add_token_b)
(pool_def_data.reserve_b * max_amount_to_add_token_a) / pool_def_data.reserve_a; .expect("reserve_a * max_amount_b overflows u128")
/ pool_def_data.reserve_b;
let ideal_b: u128 = pool_def_data
.reserve_b
.checked_mul(max_amount_to_add_token_a)
.expect("reserve_b * max_amount_a overflows u128")
/ pool_def_data.reserve_a;
let actual_amount_a = if ideal_a > max_amount_to_add_token_a { let actual_amount_a = if ideal_a > max_amount_to_add_token_a {
max_amount_to_add_token_a max_amount_to_add_token_a
@ -111,8 +117,16 @@ pub fn add_liquidity(
// 4. Calculate LP to mint // 4. Calculate LP to mint
let delta_lp = std::cmp::min( let delta_lp = std::cmp::min(
pool_def_data.liquidity_pool_supply * actual_amount_a / pool_def_data.reserve_a, pool_def_data
pool_def_data.liquidity_pool_supply * actual_amount_b / pool_def_data.reserve_b, .liquidity_pool_supply
.checked_mul(actual_amount_a)
.expect("liquidity_pool_supply * actual_amount_a overflows u128")
/ pool_def_data.reserve_a,
pool_def_data
.liquidity_pool_supply
.checked_mul(actual_amount_b)
.expect("liquidity_pool_supply * actual_amount_b overflows u128")
/ pool_def_data.reserve_b,
); );
assert!(delta_lp != 0, "Payable LP must be nonzero"); assert!(delta_lp != 0, "Payable LP must be nonzero");
@ -125,9 +139,18 @@ pub fn add_liquidity(
// 5. Update pool account // 5. Update pool account
let mut pool_post = pool.account.clone(); let mut pool_post = pool.account.clone();
let pool_post_definition = PoolDefinition { let pool_post_definition = PoolDefinition {
liquidity_pool_supply: pool_def_data.liquidity_pool_supply + delta_lp, liquidity_pool_supply: pool_def_data
reserve_a: pool_def_data.reserve_a + actual_amount_a, .liquidity_pool_supply
reserve_b: pool_def_data.reserve_b + actual_amount_b, .checked_add(delta_lp)
.expect("liquidity_pool_supply + delta_lp overflows u128"),
reserve_a: pool_def_data
.reserve_a
.checked_add(actual_amount_a)
.expect("reserve_a + actual_amount_a overflows u128"),
reserve_b: pool_def_data
.reserve_b
.checked_add(actual_amount_b)
.expect("reserve_b + actual_amount_b overflows u128"),
..pool_def_data ..pool_def_data
}; };

View File

@ -91,7 +91,11 @@ pub fn new_definition(
} }
// LP Token minting calculation // LP Token minting calculation
let initial_lp = (token_a_amount.get() * token_b_amount.get()).isqrt(); let initial_lp = token_a_amount
.get()
.checked_mul(token_b_amount.get())
.expect("token_a * token_b overflows u128")
.isqrt();
assert!( assert!(
initial_lp > MINIMUM_LIQUIDITY, initial_lp > MINIMUM_LIQUIDITY,
"Initial liquidity must exceed minimum liquidity lock" "Initial liquidity must exceed minimum liquidity lock"

View File

@ -94,10 +94,16 @@ pub fn remove_liquidity(
"Cannot remove locked minimum liquidity" "Cannot remove locked minimum liquidity"
); );
let withdraw_amount_a = let withdraw_amount_a = pool_def_data
(pool_def_data.reserve_a * remove_liquidity_amount) / pool_def_data.liquidity_pool_supply; .reserve_a
let withdraw_amount_b = .checked_mul(remove_liquidity_amount)
(pool_def_data.reserve_b * remove_liquidity_amount) / pool_def_data.liquidity_pool_supply; .expect("reserve_a * remove_liquidity_amount overflows u128")
/ pool_def_data.liquidity_pool_supply;
let withdraw_amount_b = pool_def_data
.reserve_b
.checked_mul(remove_liquidity_amount)
.expect("reserve_b * remove_liquidity_amount overflows u128")
/ pool_def_data.liquidity_pool_supply;
// 3. Validate and slippage check // 3. Validate and slippage check
assert!( assert!(
@ -115,9 +121,18 @@ pub fn remove_liquidity(
// 5. Update pool account // 5. Update pool account
let mut pool_post = pool.account.clone(); let mut pool_post = pool.account.clone();
let pool_post_definition = PoolDefinition { let pool_post_definition = PoolDefinition {
liquidity_pool_supply: pool_def_data.liquidity_pool_supply - delta_lp, liquidity_pool_supply: pool_def_data
reserve_a: pool_def_data.reserve_a - withdraw_amount_a, .liquidity_pool_supply
reserve_b: pool_def_data.reserve_b - withdraw_amount_b, .checked_sub(delta_lp)
.expect("liquidity_pool_supply - delta_lp underflows"),
reserve_a: pool_def_data
.reserve_a
.checked_sub(withdraw_amount_a)
.expect("reserve_a - withdraw_amount_a underflows"),
reserve_b: pool_def_data
.reserve_b
.checked_sub(withdraw_amount_b)
.expect("reserve_b - withdraw_amount_b underflows"),
active: true, active: true,
..pool_def_data.clone() ..pool_def_data.clone()
}; };

View File

@ -76,8 +76,18 @@ fn create_swap_post_states(
) -> Vec<AccountPostState> { ) -> Vec<AccountPostState> {
let mut pool_post = pool.account; let mut pool_post = pool.account;
let pool_post_definition = PoolDefinition { let pool_post_definition = PoolDefinition {
reserve_a: pool_def_data.reserve_a + deposit_a - withdraw_a, reserve_a: pool_def_data
reserve_b: pool_def_data.reserve_b + deposit_b - withdraw_b, .reserve_a
.checked_add(deposit_a)
.expect("reserve_a + deposit_a overflows u128")
.checked_sub(withdraw_a)
.expect("reserve_a + deposit_a - withdraw_a underflows"),
reserve_b: pool_def_data
.reserve_b
.checked_add(deposit_b)
.expect("reserve_b + deposit_b overflows u128")
.checked_sub(withdraw_b)
.expect("reserve_b + deposit_b - withdraw_b underflows"),
..pool_def_data ..pool_def_data
}; };
@ -173,7 +183,9 @@ fn swap_logic(
let withdraw_amount = reserve_withdraw_vault_amount let withdraw_amount = reserve_withdraw_vault_amount
.checked_mul(swap_amount_in) .checked_mul(swap_amount_in)
.expect("reserve * amount_in overflows u128") .expect("reserve * amount_in overflows u128")
/ (reserve_deposit_vault_amount + swap_amount_in); / reserve_deposit_vault_amount
.checked_add(swap_amount_in)
.expect("reserve + swap_amount_in overflows u128");
// Slippage check // Slippage check
assert!( assert!(

View File

@ -2774,3 +2774,248 @@ fn test_donation_then_add_liquidity_sync_mitigates_mispricing() {
assert!(synced_delta_lp < unsynced_delta_lp); assert!(synced_delta_lp < unsynced_delta_lp);
} }
#[should_panic(expected = "token_a * token_b overflows u128")]
#[test]
fn new_definition_overflow_protection() {
let large_amount = u128::MAX / 2 + 1;
let _result = new_definition(
AccountWithMetadataForTests::pool_definition_reinitializable(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::pool_lp_reinitializable(),
AccountWithMetadataForTests::lp_lock_holding_uninit(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_uninit(),
NonZero::new(large_amount).unwrap(),
NonZero::new(2).unwrap(),
AMM_PROGRAM_ID,
);
}
#[should_panic(expected = "reserve_a * max_amount_b overflows u128")]
#[test]
fn add_liquidity_overflow_protection() {
let large_reserve: u128 = u128::MAX / 2 + 1;
let reserve_b: u128 = 1_000;
let pool = AccountWithMetadata {
account: Account {
program_owner: ProgramId::default(),
balance: 0,
data: Data::from(&PoolDefinition {
definition_token_a_id: IdForTests::token_a_definition_id(),
definition_token_b_id: IdForTests::token_b_definition_id(),
vault_a_id: IdForTests::vault_a_id(),
vault_b_id: IdForTests::vault_b_id(),
liquidity_pool_id: IdForTests::token_lp_definition_id(),
liquidity_pool_supply: 1_000,
reserve_a: large_reserve,
reserve_b,
fees: 0,
active: true,
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::pool_definition_id(),
};
let vault_a = AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_a_definition_id(),
balance: large_reserve,
}),
nonce: Nonce(0),
},
is_authorized: false,
account_id: IdForTests::vault_a_id(),
};
let vault_b = AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_b_definition_id(),
balance: reserve_b,
}),
nonce: Nonce(0),
},
is_authorized: false,
account_id: IdForTests::vault_b_id(),
};
let _result = add_liquidity(
pool,
vault_a,
vault_b,
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(1).unwrap(),
500,
2, // max_amount_b=2 → reserve_a * 2 overflows
);
}
#[should_panic(expected = "reserve_a * remove_liquidity_amount overflows u128")]
#[test]
fn remove_liquidity_overflow_protection() {
let large_reserve: u128 = u128::MAX / 2 + 1;
let reserve_b: u128 = 1_000;
let lp_supply: u128 = 1_002; // must exceed MINIMUM_LIQUIDITY so remove_amount=2 passes the lock check
let pool = AccountWithMetadata {
account: Account {
program_owner: ProgramId::default(),
balance: 0,
data: Data::from(&PoolDefinition {
definition_token_a_id: IdForTests::token_a_definition_id(),
definition_token_b_id: IdForTests::token_b_definition_id(),
vault_a_id: IdForTests::vault_a_id(),
vault_b_id: IdForTests::vault_b_id(),
liquidity_pool_id: IdForTests::token_lp_definition_id(),
liquidity_pool_supply: lp_supply,
reserve_a: large_reserve,
reserve_b,
fees: 0,
active: true,
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::pool_definition_id(),
};
let vault_a = AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_a_definition_id(),
balance: large_reserve,
}),
nonce: Nonce(0),
},
is_authorized: false,
account_id: IdForTests::vault_a_id(),
};
let vault_b = AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_b_definition_id(),
balance: reserve_b,
}),
nonce: Nonce(0),
},
is_authorized: false,
account_id: IdForTests::vault_b_id(),
};
let user_lp = AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_lp_definition_id(),
balance: 2,
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::user_token_lp_id(),
};
let _result = remove_liquidity(
pool,
vault_a,
vault_b,
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
user_lp,
NonZero::new(2).unwrap(), // remove_amount=2 → reserve_a * 2 overflows
1,
1,
);
}
#[should_panic(expected = "reserve * amount_in overflows u128")]
#[test]
fn swap_exact_input_overflow_protection() {
let large_reserve: u128 = u128::MAX / 2 + 1;
let reserve_b: u128 = 1_000;
let pool = AccountWithMetadata {
account: Account {
program_owner: ProgramId::default(),
balance: 0,
data: Data::from(&PoolDefinition {
definition_token_a_id: IdForTests::token_a_definition_id(),
definition_token_b_id: IdForTests::token_b_definition_id(),
vault_a_id: IdForTests::vault_a_id(),
vault_b_id: IdForTests::vault_b_id(),
liquidity_pool_id: IdForTests::token_lp_definition_id(),
liquidity_pool_supply: 1,
reserve_a: 1_000,
reserve_b: large_reserve,
fees: 0,
active: true,
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::pool_definition_id(),
};
let vault_a = AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_a_definition_id(),
balance: reserve_b,
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::vault_a_id(),
};
let vault_b = AccountWithMetadata {
account: Account {
program_owner: TOKEN_PROGRAM_ID,
balance: 0,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_b_definition_id(),
balance: large_reserve,
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::vault_b_id(),
};
// Swap token_a in: withdraw_amount = reserve_b * swap_amount_in / (reserve_a + swap_amount_in)
// reserve_b is large, so reserve_b * 2 overflows
let _result = swap_exact_input(
pool,
vault_a,
vault_b,
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b(),
2,
1,
IdForTests::token_a_definition_id(),
);
}