diff --git a/amm/src/swap.rs b/amm/src/swap.rs index fe5c544..53d5b4c 100644 --- a/amm/src/swap.rs +++ b/amm/src/swap.rs @@ -76,20 +76,34 @@ fn create_swap_post_states( deposit_b: u128, withdraw_b: u128, ) -> Vec { + let new_reserve_a = pool_def_data + .reserve_a + .checked_add(deposit_a) + .expect("reserve_a + deposit_a overflows u128") + .checked_sub(withdraw_a) + .expect("reserve_a + deposit_a - withdraw_a underflows"); + let new_reserve_b = pool_def_data + .reserve_b + .checked_add(deposit_b) + .expect("reserve_b + deposit_b overflows u128") + .checked_sub(withdraw_b) + .expect("reserve_b + deposit_b - withdraw_b underflows"); + let (old_lo, old_hi) = pool_def_data + .reserve_a + .carrying_mul(pool_def_data.reserve_b, 0); + let (new_lo, new_hi) = new_reserve_a.carrying_mul(new_reserve_b, 0); + let old_k = (old_hi, old_lo); + let new_k = (new_hi, new_lo); + + assert!( + new_k >= old_k, + "Swap invariant violation: new k must be greater than or equal to old k" + ); + let mut pool_post = pool.account; let pool_post_definition = PoolDefinition { - reserve_a: pool_def_data - .reserve_a - .checked_add(deposit_a) - .expect("reserve_a + deposit_a overflows u128") - .checked_sub(withdraw_a) - .expect("reserve_a + deposit_a - withdraw_a underflows"), - reserve_b: pool_def_data - .reserve_b - .checked_add(deposit_b) - .expect("reserve_b + deposit_b overflows u128") - .checked_sub(withdraw_b) - .expect("reserve_b + deposit_b - withdraw_b underflows"), + reserve_a: new_reserve_a, + reserve_b: new_reserve_b, ..pool_def_data }; diff --git a/amm/src/tests.rs b/amm/src/tests.rs index 70b8434..f34ca04 100644 --- a/amm/src/tests.rs +++ b/amm/src/tests.rs @@ -2030,6 +2030,36 @@ fn test_call_swap_incorrect_token_type() { ); } +fn pool_with_reserves(reserve_a: u128, reserve_b: u128) -> AccountWithMetadata { + let mut pool = AccountWithMetadataForTests::pool_definition_init(); + let mut pool_definition = + PoolDefinition::try_from(&pool.account.data).expect("Pool definition must be valid"); + + pool_definition.reserve_a = reserve_a; + pool_definition.reserve_b = reserve_b; + pool.account.data = Data::from(&pool_definition); + + pool +} + +fn vault_a_with_balance(balance: u128) -> AccountWithMetadata { + let mut vault = AccountWithMetadataForTests::vault_a_init(); + vault.account.data = Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance, + }); + vault +} + +fn vault_b_with_balance(balance: u128) -> AccountWithMetadata { + let mut vault = AccountWithMetadataForTests::vault_b_init(); + vault.account.data = Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance, + }); + vault +} + #[should_panic(expected = "Vault A was not provided")] #[test] fn test_call_swap_vault_a_omitted() { @@ -2140,6 +2170,36 @@ fn test_call_swap_below_min_out() { ); } +#[test] +fn test_call_swap_widened_k_boundary() { + let old_reserve_a = u128::MAX - 2; + let old_reserve_b = u128::MAX - 1; + + assert!(old_reserve_a.checked_mul(old_reserve_b).is_none()); + + let (post_states, _chained_calls) = swap_exact_input( + pool_with_reserves(old_reserve_a, old_reserve_b), + vault_a_with_balance(old_reserve_a), + vault_b_with_balance(old_reserve_b), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 1, + 0, + IdForTests::token_a_definition_id(), + ); + + let pool_post = post_states[0].clone(); + let pool_post_definition = PoolDefinition::try_from(&pool_post.account().data) + .expect("Pool post-state must contain a valid definition"); + + assert_eq!(pool_post_definition.reserve_a, u128::MAX - 1); + assert_eq!(pool_post_definition.reserve_b, u128::MAX - 2); + assert!(pool_post_definition + .reserve_a + .checked_mul(pool_post_definition.reserve_b) + .is_none()); +} + #[test] fn test_call_swap_chained_call_successful_1() { let (post_states, chained_calls) = swap_exact_input(