From 2308681dcf16523cb089e73717dc166c1ac2b9c9 Mon Sep 17 00:00:00 2001 From: r4bbit <445106+0x-r4bbit@users.noreply.github.com> Date: Mon, 29 Jun 2026 08:56:48 +0200 Subject: [PATCH] fix(amm): compute pool arithmetic in u256 to avoid u128 overflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The AMM multiplied amounts in u128 — `token_a * token_b` for the initial LP in `new_definition`, `reserve * amount` in swaps, and the mul/div steps in add/remove liquidity. For realistic 18-decimal token amounts the intermediate product exceeds `u128::MAX` (~3.4e38): opening a pool with 100/200 tokens is `1e20 * 2e20 = 2e40`, which panicked and caused the sequencer to skip the transaction. Widen the intermediate arithmetic, not the stored types. Add `mul_div_floor`, `mul_div_ceil`, and `isqrt_product` to `amm_core` (using `alloy_primitives::U256`, as `spot_price_q64_64` already does): they compute the product/division/sqrt in U256 and downcast the result back to u128. Route `new_definition`, `swap_exact_input`/`swap_exact_output`, `add_liquidity`, and `remove_liquidity` through them. `swap_exact_output` keeps its ceil rounding (required input rounded up, in the pool's favour) via `mul_div_ceil`. Balances, reserves, and LP supply stay u128, so account data formats, IDLs, and the token/ata/stablecoin programs are unchanged. This lifts the usable amount range to the full u128. --- programs/amm/core/src/lib.rs | 129 ++++++++++++++++++++++++++ programs/amm/src/add.rs | 51 ++++++----- programs/amm/src/new_definition.rs | 12 ++- programs/amm/src/remove.rs | 28 +++--- programs/amm/src/swap.rs | 47 +++++----- programs/amm/src/tests.rs | 140 ++++++++++++++++++++++------- 6 files changed, 305 insertions(+), 102 deletions(-) diff --git a/programs/amm/core/src/lib.rs b/programs/amm/core/src/lib.rs index 905e097..feaa063 100644 --- a/programs/amm/core/src/lib.rs +++ b/programs/amm/core/src/lib.rs @@ -295,6 +295,54 @@ pub fn spot_price_q64_64(reserve_base: u128, reserve_quote: u128) -> u128 { u128::try_from(price).unwrap_or(u128::MAX) } +/// `floor(a * b / c)` computed in U256 so the `a * b` product can't overflow u128. +/// (Storage stays u128; only the intermediate widens.) +/// +/// # Panics +/// Panics if `c` is zero, or if the result exceeds u128. +#[must_use] +pub fn mul_div_floor(a: u128, b: u128, c: u128) -> u128 { + use alloy_primitives::U256; + assert!(c != 0, "mul_div_floor: divisor must be non-zero"); + let product = U256::from(a) + .checked_mul(U256::from(b)) + .expect("u128 * u128 always fits in U256"); + let result = product + .checked_div(U256::from(c)) + .expect("mul_div_floor: divisor is non-zero after the assertion above"); + u128::try_from(result).expect("mul_div_floor result exceeds u128") +} + +/// `ceil(a * b / c)` computed in U256 so the `a * b` product can't overflow u128. +/// (Storage stays u128; only the intermediate widens.) +/// +/// # Panics +/// Panics if `c` is zero, or if the result exceeds u128. +#[must_use] +pub fn mul_div_ceil(a: u128, b: u128, c: u128) -> u128 { + use alloy_primitives::U256; + assert!(c != 0, "mul_div_ceil: divisor must be non-zero"); + let product = U256::from(a) + .checked_mul(U256::from(b)) + .expect("u128 * u128 always fits in U256"); + let result = product.div_ceil(U256::from(c)); + u128::try_from(result).expect("mul_div_ceil result exceeds u128") +} + +/// `floor(sqrt(a * b))` computed in U256 so the `a * b` product can't overflow u128. +/// +/// # Panics +/// Panics if the result exceeds u128. +#[must_use] +pub fn isqrt_product(a: u128, b: u128) -> u128 { + use alloy_primitives::U256; + let product = U256::from(a) + .checked_mul(U256::from(b)) + .expect("u128 * u128 always fits in U256"); + let root = product.root(2); // ruint integer root; floor sqrt + u128::try_from(root).expect("isqrt_product result exceeds u128") +} + impl TryFrom<&Data> for PoolDefinition { type Error = std::io::Error; @@ -546,4 +594,85 @@ mod tests { fn zero_reserve_base_panics() { let _ = spot_price_q64_64(0, 1_000); } + + #[test] + fn mul_div_floor_small_cases() { + assert_eq!(mul_div_floor(6, 7, 3), 14); + // floor(7 * 7 / 3) = floor(49/3) = 16 + assert_eq!(mul_div_floor(7, 7, 3), 16); + assert_eq!(mul_div_floor(0, 12345, 7), 0); + assert_eq!(mul_div_floor(1, 1, 2), 0); + } + + #[test] + fn mul_div_floor_product_exceeds_u128() { + // 2e30 * 2e30 = 4e60, far beyond u128; / 1e20 = 4e40, still beyond u128 -- but the + // intermediate must not overflow and the *quotient* here fits once divided down. + // 2e30 * 2e30 / 2e30 = 2e30 fits in u128. + let a = 2_000_000_000_000_000_000_000_000_000_000u128; // 2e30 + assert_eq!(mul_div_floor(a, a, a), a); + // 2e30 * 2e30 / 1e20 = 4e40 would exceed u128 -- verify it panics on downcast. + } + + #[test] + #[should_panic(expected = "mul_div_floor result exceeds u128")] + fn mul_div_floor_result_exceeds_u128_panics() { + let a = 2_000_000_000_000_000_000_000_000_000_000u128; // 2e30 + let c = 100_000_000_000_000_000_000u128; // 1e20 + let _ = mul_div_floor(a, a, c); // 4e40 > u128::MAX + } + + #[test] + #[should_panic(expected = "mul_div_floor: divisor must be non-zero")] + fn mul_div_floor_zero_divisor_panics() { + let _ = mul_div_floor(1, 2, 0); + } + + #[test] + fn mul_div_ceil_small_cases() { + assert_eq!(mul_div_ceil(6, 7, 3), 14); + // ceil(7 * 7 / 3) = ceil(49/3) = 17 + assert_eq!(mul_div_ceil(7, 7, 3), 17); + // exact division: no rounding up + assert_eq!(mul_div_ceil(6, 4, 3), 8); + assert_eq!(mul_div_ceil(0, 12345, 7), 0); + } + + #[test] + fn mul_div_ceil_product_exceeds_u128() { + // (2e30 * 2e30) / 2e30 = 2e30 exactly, fits in u128. + let a = 2_000_000_000_000_000_000_000_000_000_000u128; // 2e30 + assert_eq!(mul_div_ceil(a, a, a), a); + } + + #[test] + #[should_panic(expected = "mul_div_ceil: divisor must be non-zero")] + fn mul_div_ceil_zero_divisor_panics() { + let _ = mul_div_ceil(1, 2, 0); + } + + #[test] + fn isqrt_product_matches_u128_isqrt_for_small_values() { + assert_eq!(isqrt_product(100, 100), 100); + assert_eq!(isqrt_product(2, 8), 4); + // floor(sqrt(7 * 7)) = 7, floor(sqrt(50)) = 7 + assert_eq!(isqrt_product(7, 7), 7); + assert_eq!(isqrt_product(5, 10), 50u128.isqrt()); + } + + #[test] + fn isqrt_product_handles_the_1e20_times_2e20_overflow_case() { + // 1e20 * 2e20 = 2e40 overflows u128 (max ~3.4e38); the U256 intermediate keeps it exact. + let a = 100_000_000_000_000_000_000u128; // 1e20 + let b = 200_000_000_000_000_000_000u128; // 2e20 + // floor(sqrt(2e40)) computed independently in U256. + let expected = { + use alloy_primitives::U256; + let product = U256::from(a).checked_mul(U256::from(b)).unwrap(); + u128::try_from(product.root(2)).unwrap() + }; + assert_eq!(isqrt_product(a, b), expected); + // Sanity: floor(sqrt(2e40)) = floor(1.4142...e20) = 141421356237309504880. + assert_eq!(isqrt_product(a, b), 141_421_356_237_309_504_880); + } } diff --git a/programs/amm/src/add.rs b/programs/amm/src/add.rs index 3386b2b..80d03b3 100644 --- a/programs/amm/src/add.rs +++ b/programs/amm/src/add.rs @@ -2,8 +2,8 @@ use std::num::NonZeroU128; use amm_core::{ assert_supported_fee_tier, compute_config_pda, compute_liquidity_token_pda_seed, - compute_pool_pda_seed, read_vault_fungible_balances, spot_price_q64_64, AmmConfig, - PoolDefinition, + compute_pool_pda_seed, mul_div_floor, read_vault_fungible_balances, spot_price_q64_64, + AmmConfig, PoolDefinition, }; use clock_core::CLOCK_01_PROGRAM_ACCOUNT_ID; use nssa_core::{ @@ -113,18 +113,18 @@ pub fn add_liquidity( assert!(pool_def_data.reserve_a != 0, "Reserves must be nonzero"); assert!(pool_def_data.reserve_b != 0, "Reserves must be nonzero"); - let ideal_a: u128 = pool_def_data - .reserve_a - .checked_mul(max_amount_to_add_token_b) - .expect("reserve_a * max_amount_b overflows u128") - .checked_div(pool_def_data.reserve_b) - .expect("reserve_b must be nonzero after validation"); - let ideal_b: u128 = pool_def_data - .reserve_b - .checked_mul(max_amount_to_add_token_a) - .expect("reserve_b * max_amount_a overflows u128") - .checked_div(pool_def_data.reserve_a) - .expect("reserve_a must be nonzero after validation"); + // floor(reserve * max_amount / reserve), products widened to U256. Reserves are nonzero + // (asserted above), so the divisors are valid. + let ideal_a: u128 = mul_div_floor( + pool_def_data.reserve_a, + max_amount_to_add_token_b, + pool_def_data.reserve_b, + ); + let ideal_b: u128 = mul_div_floor( + pool_def_data.reserve_b, + max_amount_to_add_token_a, + pool_def_data.reserve_a, + ); let actual_amount_a = if ideal_a > max_amount_to_add_token_a { max_amount_to_add_token_a @@ -151,19 +151,18 @@ pub fn add_liquidity( assert!(actual_amount_b != 0, "A trade amount is 0"); // 4. Calculate LP to mint + // floor(supply * actual / reserve), products widened to U256. let delta_lp = std::cmp::min( - pool_def_data - .liquidity_pool_supply - .checked_mul(actual_amount_a) - .expect("liquidity_pool_supply * actual_amount_a overflows u128") - .checked_div(pool_def_data.reserve_a) - .expect("reserve_a must be nonzero after validation"), - pool_def_data - .liquidity_pool_supply - .checked_mul(actual_amount_b) - .expect("liquidity_pool_supply * actual_amount_b overflows u128") - .checked_div(pool_def_data.reserve_b) - .expect("reserve_b must be nonzero after validation"), + mul_div_floor( + pool_def_data.liquidity_pool_supply, + actual_amount_a, + pool_def_data.reserve_a, + ), + mul_div_floor( + pool_def_data.liquidity_pool_supply, + actual_amount_b, + pool_def_data.reserve_b, + ), ); assert!(delta_lp != 0, "Payable LP must be nonzero"); diff --git a/programs/amm/src/new_definition.rs b/programs/amm/src/new_definition.rs index a78c221..02ce020 100644 --- a/programs/amm/src/new_definition.rs +++ b/programs/amm/src/new_definition.rs @@ -4,7 +4,8 @@ use amm_core::{ assert_supported_fee_tier, compute_config_pda, compute_liquidity_token_pda, compute_liquidity_token_pda_seed, compute_lp_lock_holding_pda, compute_lp_lock_holding_pda_seed, compute_pool_pda, compute_pool_pda_seed, compute_vault_pda, - compute_vault_pda_seed, spot_price_q64_64, AmmConfig, PoolDefinition, MINIMUM_LIQUIDITY, + compute_vault_pda_seed, isqrt_product, spot_price_q64_64, AmmConfig, PoolDefinition, + MINIMUM_LIQUIDITY, }; use clock_core::CLOCK_01_PROGRAM_ACCOUNT_ID; use nssa_core::{ @@ -117,12 +118,9 @@ pub fn new_definition( "New definition: clock account must be the canonical 1-block LEZ clock account" ); - // LP Token minting calculation - let initial_lp = token_a_amount - .get() - .checked_mul(token_b_amount.get()) - .expect("token_a * token_b overflows u128") - .isqrt(); + // LP Token minting calculation. The `token_a * token_b` product is computed in U256 (via + // `isqrt_product`) so realistic 18-decimal amounts can't overflow u128 before the sqrt. + let initial_lp = isqrt_product(token_a_amount.get(), token_b_amount.get()); assert!( initial_lp > MINIMUM_LIQUIDITY, "Initial liquidity must exceed minimum liquidity lock" diff --git a/programs/amm/src/remove.rs b/programs/amm/src/remove.rs index 5412ae1..2d0bbf4 100644 --- a/programs/amm/src/remove.rs +++ b/programs/amm/src/remove.rs @@ -2,8 +2,8 @@ use std::num::NonZeroU128; use amm_core::{ assert_supported_fee_tier, compute_config_pda, compute_liquidity_token_pda_seed, - compute_pool_pda_seed, compute_vault_pda_seed, spot_price_q64_64, AmmConfig, PoolDefinition, - MINIMUM_LIQUIDITY, + compute_pool_pda_seed, compute_vault_pda_seed, mul_div_floor, spot_price_q64_64, AmmConfig, + PoolDefinition, MINIMUM_LIQUIDITY, }; use clock_core::CLOCK_01_PROGRAM_ACCOUNT_ID; use nssa_core::{ @@ -156,18 +156,18 @@ pub fn remove_liquidity( "Cannot remove locked minimum liquidity" ); - let withdraw_amount_a = pool_def_data - .reserve_a - .checked_mul(remove_liquidity_amount) - .expect("reserve_a * remove_liquidity_amount overflows u128") - .checked_div(pool_def_data.liquidity_pool_supply) - .expect("liquidity supply must be nonzero after validation"); - let withdraw_amount_b = pool_def_data - .reserve_b - .checked_mul(remove_liquidity_amount) - .expect("reserve_b * remove_liquidity_amount overflows u128") - .checked_div(pool_def_data.liquidity_pool_supply) - .expect("liquidity supply must be nonzero after validation"); + // floor(reserve * remove_amount / supply), products widened to U256. Supply exceeds + // MINIMUM_LIQUIDITY (asserted above), so the divisor is nonzero. + let withdraw_amount_a = mul_div_floor( + pool_def_data.reserve_a, + remove_liquidity_amount, + pool_def_data.liquidity_pool_supply, + ); + let withdraw_amount_b = mul_div_floor( + pool_def_data.reserve_b, + remove_liquidity_amount, + pool_def_data.liquidity_pool_supply, + ); // 3. Validate and slippage check assert!( diff --git a/programs/amm/src/swap.rs b/programs/amm/src/swap.rs index e869c55..d6a1d2e 100644 --- a/programs/amm/src/swap.rs +++ b/programs/amm/src/swap.rs @@ -1,6 +1,6 @@ use amm_core::{ - assert_supported_fee_tier, compute_config_pda, compute_pool_pda_seed, - read_vault_fungible_balances, spot_price_q64_64, AmmConfig, FEE_BPS_DENOMINATOR, + assert_supported_fee_tier, compute_config_pda, compute_pool_pda_seed, mul_div_ceil, + mul_div_floor, read_vault_fungible_balances, spot_price_q64_64, AmmConfig, FEE_BPS_DENOMINATOR, MINIMUM_LIQUIDITY, }; pub use amm_core::{compute_liquidity_token_pda_seed, compute_vault_pda_seed, PoolDefinition}; @@ -270,11 +270,8 @@ fn swap_logic( let fee_multiplier = FEE_BPS_DENOMINATOR .checked_sub(fee_bps) .expect("fee_bps exceeds fee denominator"); - let effective_amount_in = swap_amount_in - .checked_mul(fee_multiplier) - .expect("swap_amount_in * (FEE_BPS_DENOMINATOR - fee_bps) overflows u128") - .checked_div(FEE_BPS_DENOMINATOR) - .expect("fee denominator must be nonzero"); + // floor(swap_amount_in * fee_multiplier / FEE_BPS_DENOMINATOR), product widened to U256. + let effective_amount_in = mul_div_floor(swap_amount_in, fee_multiplier, FEE_BPS_DENOMINATOR); assert!( effective_amount_in != 0, "Effective swap amount should be nonzero" @@ -283,15 +280,16 @@ fn swap_logic( // The recorded pool reserves are updated later with the full // `swap_amount_in`, so LP fees accrue inside `reserve_*` via invariant // growth rather than as a separate vault balance surplus over `reserve_*`. - let withdraw_amount = reserve_withdraw_vault_amount - .checked_mul(effective_amount_in) - .expect("reserve * effective_amount_in overflows u128") - .checked_div( - reserve_deposit_vault_amount - .checked_add(effective_amount_in) - .expect("reserve + effective_amount_in overflows u128"), - ) - .expect("reserve plus effective input must be nonzero"); + // The denominator sum stays u128 (overflows only near u128::MAX, an unstorable reserve); + // only the `reserve * effective` product is widened to U256. + let reserve_plus_effective = reserve_deposit_vault_amount + .checked_add(effective_amount_in) + .expect("reserve + effective_amount_in overflows u128"); + let withdraw_amount = mul_div_floor( + reserve_withdraw_vault_amount, + effective_amount_in, + reserve_plus_effective, + ); // Slippage check assert!( @@ -483,23 +481,24 @@ fn exact_output_swap_logic( // // Solve constant product for effective_in (fee already removed): // effective_in >= ceil(reserve_in * amount_out / (reserve_out - amount_out)) - let effective_in_numerator = reserve_deposit_vault_amount - .checked_mul(exact_amount_out) - .expect("reserve * amount_out overflows u128"); + // ceil(reserve_in * amount_out / (reserve_out - amount_out)). The `reserve_in * amount_out` + // product is widened to U256; the denominator is a subtraction that stays u128. let effective_in_denominator = reserve_withdraw_vault_amount .checked_sub(exact_amount_out) .expect("reserve_out - amount_out underflows"); - let effective_in_min = effective_in_numerator.div_ceil(effective_in_denominator); + let effective_in_min = mul_div_ceil( + reserve_deposit_vault_amount, + exact_amount_out, + effective_in_denominator, + ); // Lift back to gross input so that // floor(gross_in * (FEE_DENOM - fee) / FEE_DENOM) >= effective_in_min + // ceil(effective_in_min * FEE_BPS_DENOMINATOR / fee_multiplier), product widened to U256. let fee_multiplier = FEE_BPS_DENOMINATOR .checked_sub(fee_bps) .expect("fee_bps exceeds fee denominator"); - let deposit_amount = effective_in_min - .checked_mul(FEE_BPS_DENOMINATOR) - .expect("effective_in * FEE_DENOM overflows u128") - .div_ceil(fee_multiplier); + let deposit_amount = mul_div_ceil(effective_in_min, FEE_BPS_DENOMINATOR, fee_multiplier); // Slippage check assert!( diff --git a/programs/amm/src/tests.rs b/programs/amm/src/tests.rs index d6103e5..3878b5e 100644 --- a/programs/amm/src/tests.rs +++ b/programs/amm/src/tests.rs @@ -10,9 +10,9 @@ use std::num::NonZero; use amm_core::{ compute_config_pda, compute_liquidity_token_pda, compute_liquidity_token_pda_seed, compute_lp_lock_holding_pda, compute_lp_lock_holding_pda_seed, compute_pool_pda, - compute_pool_pda_seed, compute_vault_pda, compute_vault_pda_seed, AmmConfig, PoolDefinition, - FEE_BPS_DENOMINATOR, FEE_TIER_BPS_1, FEE_TIER_BPS_100, FEE_TIER_BPS_30, FEE_TIER_BPS_5, - MINIMUM_LIQUIDITY, + compute_pool_pda_seed, compute_vault_pda, compute_vault_pda_seed, isqrt_product, mul_div_floor, + AmmConfig, PoolDefinition, FEE_BPS_DENOMINATOR, FEE_TIER_BPS_1, FEE_TIER_BPS_100, + FEE_TIER_BPS_30, FEE_TIER_BPS_5, MINIMUM_LIQUIDITY, }; use nssa_core::{ account::{Account, AccountId, AccountWithMetadata, Data, Nonce}, @@ -3107,10 +3107,12 @@ fn call_swap_exact_output_accepts_smallest_max_in_for_rounded_boundary() { ); } -// Without the fix, `reserve_a * exact_amount_out` silently wraps to 0 in release mode, -// making `deposit_amount = 0`. The slippage check `0 <= max_amount_in` always passes, -// so an attacker receives `exact_amount_out` tokens while paying nothing. -#[should_panic(expected = "reserve * amount_out overflows u128")] +// Without widening, `reserve_a * exact_amount_out` silently wraps to 0 in release mode, making +// `deposit_amount = 0`, so an attacker would receive `exact_amount_out` tokens while paying +// nothing. Under Option A the product is computed in U256, so the true (enormous) required input +// is computed exactly and the slippage check `deposit_amount <= max_amount_in` correctly rejects +// the attacker's tiny `max_amount_in`. +#[should_panic(expected = "Required input exceeds maximum amount in")] #[test] fn swap_exact_output_overflow_protection() { // reserve_a chosen so that reserve_a * 2 overflows u128: @@ -3286,6 +3288,42 @@ fn test_new_definition_lp_symmetric_amounts() { assert_eq!(chained_call_lp_user, expected_lp_user_call); } +#[test] +fn test_new_definition_large_18_decimal_amounts_no_overflow() { + // 100e18 and 200e18 (1e20 / 2e20). The naive `token_a * token_b` product is 2e40, which + // overflows u128 (max ~3.4e38) and previously panicked. `isqrt_product` widens the product + // to U256, so this must now succeed. Expected LP = floor(sqrt(2e40)). + let token_a_amount = 100_000_000_000_000_000_000u128; // 1e20 + let token_b_amount = 200_000_000_000_000_000_000u128; // 2e20 + let expected_lp = isqrt_product(token_a_amount, token_b_amount); + assert_eq!(expected_lp, 141_421_356_237_309_504_880); + + let (post_states, _chained_calls) = new_definition( + AccountWithMetadataForTests::config_init(), + AccountWithMetadataForTests::pool_definition_uninit(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::pool_lp_uninit(), + AccountWithMetadataForTests::lp_lock_holding_uninit(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_uninit(), + AccountWithMetadataForTests::current_tick_account_uninit(), + AccountWithMetadataForTests::clock(), + NonZero::new(token_a_amount).unwrap(), + NonZero::new(token_b_amount).unwrap(), + BalanceForTests::fee_tier(), + AMM_PROGRAM_ID, + ); + + let pool_post = post_states[1].clone(); + let pool_def = PoolDefinition::try_from(&pool_post.account().data).unwrap(); + assert_eq!(pool_def.reserve_a, token_a_amount); + assert_eq!(pool_def.reserve_b, token_b_amount); + assert_eq!(pool_def.liquidity_pool_supply, expected_lp); + assert!(pool_def.liquidity_pool_supply > MINIMUM_LIQUIDITY); +} + #[test] fn test_minimum_liquidity_lock_and_remove_all_user_lp() { let pool_uninitialized = AccountWithMetadata { @@ -3552,12 +3590,14 @@ fn test_donation_then_add_liquidity_sync_mitigates_mispricing() { assert!(synced_delta_lp < unsynced_delta_lp); } -#[should_panic(expected = "token_a * token_b overflows u128")] +// Under Option A the `token_a * token_b` product is computed in U256, so a product that exceeds +// u128 no longer panics: it is square-rooted exactly. Here `large_amount * 2 = 2^128`, whose +// integer sqrt is `2^64`. Previously this multiplication overflowed u128 and panicked. #[test] fn new_definition_overflow_protection() { - let large_amount = u128::MAX / 2 + 1; + let large_amount = u128::MAX / 2 + 1; // 2^127 - let _result = new_definition( + let (post_states, _chained_calls) = new_definition( AccountWithMetadataForTests::config_init(), AccountWithMetadataForTests::pool_definition_uninit(), AccountWithMetadataForTests::vault_a_init(), @@ -3574,13 +3614,21 @@ fn new_definition_overflow_protection() { BalanceForTests::fee_tier(), AMM_PROGRAM_ID, ); + + let pool_def = PoolDefinition::try_from(&post_states[1].account().data).unwrap(); + // floor(sqrt(2^127 * 2)) = floor(sqrt(2^128)) = 2^64. + assert_eq!(pool_def.liquidity_pool_supply, 1u128 << 64); + assert_eq!(pool_def.reserve_a, large_amount); + assert_eq!(pool_def.reserve_b, 2); } -#[should_panic(expected = "reserve_a * max_amount_b overflows u128")] +// Under Option A the `reserve * max_amount` and `supply * actual` products are computed in U256, so +// realistic large reserves no longer overflow u128. Here every product (reserve_a * max_b, +// supply * actual, etc.) is `1e30 * 1e30 = 1e60`, far beyond u128 (max ~3.4e38), yet the add +// succeeds and computes the correct widened results. Previously these multiplications panicked. #[test] fn add_liquidity_overflow_protection() { - let large_reserve: u128 = u128::MAX / 2 + 1; - let reserve_b: u128 = 1_000; + let large: u128 = 1_000_000_000_000_000_000_000_000_000_000; // 1e30 let pool = AccountWithMetadata { account: Account { @@ -3592,9 +3640,9 @@ fn add_liquidity_overflow_protection() { vault_a_id: IdForTests::vault_a_id(), vault_b_id: IdForTests::vault_b_id(), liquidity_pool_id: IdForTests::token_lp_definition_id(), - liquidity_pool_supply: MINIMUM_LIQUIDITY, - reserve_a: large_reserve, - reserve_b, + liquidity_pool_supply: large, + reserve_a: large, + reserve_b: large, fees: BalanceForTests::fee_tier(), }), nonce: Nonce(0), @@ -3609,7 +3657,7 @@ fn add_liquidity_overflow_protection() { balance: 0, data: Data::from(&TokenHolding::Fungible { definition_id: IdForTests::token_a_definition_id(), - balance: large_reserve, + balance: large, }), nonce: Nonce(0), }, @@ -3623,7 +3671,7 @@ fn add_liquidity_overflow_protection() { balance: 0, data: Data::from(&TokenHolding::Fungible { definition_id: IdForTests::token_b_definition_id(), - balance: reserve_b, + balance: large, }), nonce: Nonce(0), }, @@ -3631,7 +3679,7 @@ fn add_liquidity_overflow_protection() { account_id: IdForTests::vault_b_id(), }; - let _result = add_liquidity( + let (post_states, _chained_calls) = add_liquidity( AccountWithMetadataForTests::config_init(), pool, vault_a, @@ -3643,16 +3691,24 @@ fn add_liquidity_overflow_protection() { AccountWithMetadataForTests::current_tick_account_uninit(), AccountWithMetadataForTests::clock(), NonZero::new(1).unwrap(), - 500, - 2, // max_amount_b=2 → reserve_a * 2 overflows + large, // max_amount_a + large, // max_amount_b AMM_PROGRAM_ID, ); + + let pool_def = PoolDefinition::try_from(&post_states[1].account().data).unwrap(); + // Balanced add of `1e30` to each `1e30` reserve mints `delta_lp = 1e30`. + assert_eq!(pool_def.reserve_a, large + large); + assert_eq!(pool_def.reserve_b, large + large); + assert_eq!(pool_def.liquidity_pool_supply, large + large); } -#[should_panic(expected = "reserve_a * remove_liquidity_amount overflows u128")] +// Under Option A the `reserve * remove_amount` product is computed in U256, so a product that +// exceeds u128 no longer panics: `large_reserve * 2 = 2^128` is divided down to a valid u128 +// withdraw. Previously this multiplication overflowed u128 and panicked. #[test] fn remove_liquidity_overflow_protection() { - let large_reserve: u128 = u128::MAX / 2 + 1; + let large_reserve: u128 = u128::MAX / 2 + 1; // 2^127 let reserve_b: u128 = 1_000; let lp_supply: u128 = 1_002; // must exceed MINIMUM_LIQUIDITY so remove_amount=2 passes the lock check @@ -3719,7 +3775,7 @@ fn remove_liquidity_overflow_protection() { account_id: IdForTests::user_token_lp_id(), }; - let _result = remove_liquidity( + let (post_states, _chained_calls) = remove_liquidity( AccountWithMetadataForTests::config_init(), pool, vault_a, @@ -3730,18 +3786,27 @@ fn remove_liquidity_overflow_protection() { user_lp, AccountWithMetadataForTests::current_tick_account_uninit(), AccountWithMetadataForTests::clock(), - NonZero::new(2).unwrap(), /* remove_amount=2 → reserve_a * 2 - * overflows */ + NonZero::new(2).unwrap(), // remove_amount=2 → reserve_a * 2 = 2^128 (widened to U256) 1, 1, AMM_PROGRAM_ID, ); + + // withdraw_a = floor(reserve_a * 2 / supply); withdraw_b = floor(1000 * 2 / 1002) = 1. + let expected_withdraw_a = mul_div_floor(large_reserve, 2, lp_supply); + let expected_withdraw_b = mul_div_floor(reserve_b, 2, lp_supply); + let pool_def = PoolDefinition::try_from(&post_states[1].account().data).unwrap(); + assert_eq!(pool_def.reserve_a, large_reserve - expected_withdraw_a); + assert_eq!(pool_def.reserve_b, reserve_b - expected_withdraw_b); + assert_eq!(pool_def.liquidity_pool_supply, lp_supply - 2); } -#[should_panic(expected = "reserve * effective_amount_in overflows u128")] +// Under Option A the `reserve_out * effective_amount_in` product is computed in U256, so a product +// that exceeds u128 no longer panics: `reserve_b * 2 = 2^128` is divided down to a valid u128 +// withdraw. Previously this multiplication overflowed u128 and panicked. #[test] fn swap_exact_input_overflow_protection() { - let large_reserve: u128 = u128::MAX / 2 + 1; + let large_reserve: u128 = u128::MAX / 2 + 1; // 2^127 let reserve_b: u128 = 1_000; let pool = AccountWithMetadata { @@ -3794,9 +3859,9 @@ fn swap_exact_input_overflow_protection() { }; // Swap token_a in: withdraw_amount = reserve_b * effective_amount_in / (reserve_a + - // effective_amount_in) With fee_bps=30: effective_amount_in = 3 * 9970 / 10000 = 2 - // reserve_b is large, so reserve_b * 2 overflows - let _result = swap_exact_input( + // effective_amount_in). With fee_bps=30: effective_amount_in = floor(3 * 9970 / 10000) = 2. + // reserve_b is large, so `reserve_b * 2 = 2^128` is widened to U256 rather than overflowing. + let (post_states, _chained_calls) = swap_exact_input( AccountWithMetadataForTests::config_init(), pool, vault_a, @@ -3810,6 +3875,19 @@ fn swap_exact_input_overflow_protection() { IdForTests::token_a_definition_id(), AMM_PROGRAM_ID, ); + + let fee_multiplier = FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier(); + let effective_amount_in = mul_div_floor(3, fee_multiplier, FEE_BPS_DENOMINATOR); + let expected_withdraw = mul_div_floor( + large_reserve, + effective_amount_in, + 1_000 + effective_amount_in, + ); + let pool_def = PoolDefinition::try_from(&post_states[1].account().data).unwrap(); + // token_a in: reserve_a grows by the full swap_amount_in (3); reserve_b shrinks by the + // withdraw. + assert_eq!(pool_def.reserve_a, 1_000 + 3); + assert_eq!(pool_def.reserve_b, large_reserve - expected_withdraw); } #[test]