fixed omitted vault checks

Previous versions mistakenly used token_definition instead of vault account id to check for vaults.

Functions and corresponding tests fixed.

Minor error in arithmetic in state.rs for remove_liquidity fixed.
This commit is contained in:
jonesmarvin8 2025-11-28 08:11:04 -05:00
parent 863ed888ad
commit 3e9f0f9384
2 changed files with 55 additions and 64 deletions

View File

@ -335,20 +335,18 @@ fn swap(
// Verify vaults are in fact vaults // Verify vaults are in fact vaults
let pool_def_data = PoolDefinition::parse(&pool.account.data).unwrap(); let pool_def_data = PoolDefinition::parse(&pool.account.data).unwrap();
let vault1_data = TokenHolding::parse(&vault1.account.data).unwrap();
let vault2_data = TokenHolding::parse(&vault2.account.data).unwrap();
let vault_a = if vault1_data.definition_id == pool_def_data.definition_token_a_id { let vault_a = if vault1.account_id == pool_def_data.vault_a_addr {
vault1.clone() vault1.clone()
} else if vault2_data.definition_id == pool_def_data.definition_token_a_id { } else if vault2.account_id == pool_def_data.vault_a_addr {
vault2.clone() vault2.clone()
} else { } else {
panic!("Vault A was not provided"); panic!("Vault A was not provided");
}; };
let vault_b = if vault1_data.definition_id == pool_def_data.definition_token_b_id { let vault_b = if vault1.account_id == pool_def_data.vault_b_addr {
vault1.clone() vault1.clone()
} else if vault2_data.definition_id == pool_def_data.definition_token_b_id { } else if vault2.account_id == pool_def_data.vault_b_addr {
vault2.clone() vault2.clone()
} else { } else {
panic!("Vault B was not provided"); panic!("Vault B was not provided");
@ -485,29 +483,24 @@ fn add_liquidity(pre_states: &[AccountWithMetadata],
let user_b = &pre_states[5]; let user_b = &pre_states[5];
let user_lp = &pre_states[6]; let user_lp = &pre_states[6];
let mut vault_a = AccountWithMetadata::default(); // Verify vaults are in fact vaults
let mut vault_b = AccountWithMetadata::default();
let pool_def_data = PoolDefinition::parse(&pool.account.data).unwrap(); let pool_def_data = PoolDefinition::parse(&pool.account.data).unwrap();
let vault1_data = TokenHolding::parse(&vault1.account.data).unwrap(); let vault_a = if vault1.account_id == pool_def_data.vault_a_addr {
let vault2_data = TokenHolding::parse(&vault2.account.data).unwrap(); vault1.clone()
} else if vault2.account_id == pool_def_data.vault_a_addr {
if vault1_data.definition_id == pool_def_data.definition_token_a_id { vault2.clone()
vault_a = vault1.clone(); } else {
} else if vault2_data.definition_id == pool_def_data.definition_token_a_id { panic!("Vault A was not provided");
vault_a = vault2.clone(); };
} else {
panic!("Vault A was not provided");
}
if vault1_data.definition_id == pool_def_data.definition_token_b_id { let vault_b = if vault1.account_id == pool_def_data.vault_b_addr {
vault_b = vault1.clone(); vault1.clone()
} else if vault2_data.definition_id == pool_def_data.definition_token_b_id { } else if vault2.account_id == pool_def_data.vault_b_addr {
vault_b = vault2.clone(); vault2.clone()
} else { } else {
panic!("Vault B was not provided"); panic!("Vault B was not provided");
} };
if max_balance_in.len() != 2 { if max_balance_in.len() != 2 {
panic!("Invalid number of input balances"); panic!("Invalid number of input balances");
@ -647,29 +640,25 @@ fn remove_liquidity(pre_states: &[AccountWithMetadata]) -> (Vec<Account>, Vec<Ch
let user_a = &pre_states[4]; let user_a = &pre_states[4];
let user_b = &pre_states[5]; let user_b = &pre_states[5];
let user_lp = &pre_states[6]; let user_lp = &pre_states[6];
let mut vault_a = AccountWithMetadata::default();
let mut vault_b = AccountWithMetadata::default();
// Verify vaults are in fact vaults
let pool_def_data = PoolDefinition::parse(&pool.account.data).unwrap(); let pool_def_data = PoolDefinition::parse(&pool.account.data).unwrap();
let vault1_data = TokenHolding::parse(&vault1.account.data).unwrap();
let vault2_data = TokenHolding::parse(&vault2.account.data).unwrap();
if vault1_data.definition_id == pool_def_data.definition_token_a_id { let vault_a = if vault1.account_id == pool_def_data.vault_a_addr {
vault_a = vault1.clone(); vault1.clone()
} else if vault2_data.definition_id == pool_def_data.definition_token_a_id { } else if vault2.account_id == pool_def_data.vault_a_addr {
vault_a = vault2.clone(); vault2.clone()
} else { } else {
panic!("Vault A was not provided"); panic!("Vault A was not provided");
} };
if vault1_data.definition_id == pool_def_data.definition_token_b_id { let vault_b = if vault1.account_id == pool_def_data.vault_b_addr {
vault_b = vault1.clone(); vault1.clone()
} else if vault2_data.definition_id == pool_def_data.definition_token_b_id { } else if vault2.account_id == pool_def_data.vault_b_addr {
vault_b = vault2.clone(); vault2.clone()
} else { } else {
panic!("Vault B was not provided"); panic!("Vault B was not provided");
} };
// 2. Determine deposit amounts // 2. Determine deposit amounts
let user_lp_amt = TokenHolding::parse(&user_lp.account.data).unwrap().balance; let user_lp_amt = TokenHolding::parse(&user_lp.account.data).unwrap().balance;
@ -1545,7 +1534,7 @@ mod tests {
vault1.data = TokenHolding::into_data( vault1.data = TokenHolding::into_data(
TokenHolding { account_type: TOKEN_HOLDING_TYPE, TokenHolding { account_type: TOKEN_HOLDING_TYPE,
definition_id:definition_token_b_id.clone(), definition_id:definition_token_a_id.clone(),
balance: 15u128 } balance: 15u128 }
); );
@ -1632,7 +1621,7 @@ mod tests {
vault2.data = TokenHolding::into_data( vault2.data = TokenHolding::into_data(
TokenHolding { account_type: TOKEN_HOLDING_TYPE, TokenHolding { account_type: TOKEN_HOLDING_TYPE,
definition_id:definition_token_a_id.clone(), definition_id:definition_token_b_id.clone(),
balance: 15u128 } balance: 15u128 }
); );
@ -1782,12 +1771,12 @@ mod tests {
let vault_a = AccountWithMetadata { let vault_a = AccountWithMetadata {
account: vault_a.clone(), account: vault_a.clone(),
is_authorized: true, is_authorized: true,
account_id: AccountId::new([2; 32])}; account_id: vault_a_addr.clone(),};
let vault_b = AccountWithMetadata { let vault_b = AccountWithMetadata {
account: vault_b.clone(), account: vault_b.clone(),
is_authorized: true, is_authorized: true,
account_id: AccountId::new([3; 32])}; account_id: vault_b_addr.clone()};
let pool_lp = AccountWithMetadata { let pool_lp = AccountWithMetadata {
account: pool_lp.clone(), account: pool_lp.clone(),
@ -2104,7 +2093,7 @@ mod tests {
vault1.data = TokenHolding::into_data( vault1.data = TokenHolding::into_data(
TokenHolding { account_type: TOKEN_HOLDING_TYPE, TokenHolding { account_type: TOKEN_HOLDING_TYPE,
definition_id:definition_token_b_id.clone(), definition_id:definition_token_a_id.clone(),
balance: 15u128 } balance: 15u128 }
); );
@ -2193,7 +2182,7 @@ mod tests {
vault2.data = TokenHolding::into_data( vault2.data = TokenHolding::into_data(
TokenHolding { account_type: TOKEN_HOLDING_TYPE, TokenHolding { account_type: TOKEN_HOLDING_TYPE,
definition_id:definition_token_a_id.clone(), definition_id:definition_token_b_id.clone(),
balance: 15u128 } balance: 15u128 }
); );
@ -3169,7 +3158,7 @@ mod tests {
let mut user_lp = Account::default(); let mut user_lp = Account::default();
let definition_token_a_id = AccountId::new([1;32]); let definition_token_a_id = AccountId::new([1;32]);
let definition_token_b_id = AccountId::new([2;32]); let definition_token_b_id = AccountId::new([2;32]);
vault_a.data = TokenHolding::into_data( vault_a.data = TokenHolding::into_data(
TokenHolding { account_type: TOKEN_HOLDING_TYPE, TokenHolding { account_type: TOKEN_HOLDING_TYPE,
@ -3188,14 +3177,14 @@ mod tests {
1, 1, 1, 1, 1, 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 1, 1, 1, 1, 1, 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
]; ];
let vault_a_addr = AccountId::new([5;32]);
let vault_b_addr = AccountId::new([6;32]);
let liquidity_pool_id = AccountId::new([7;32]); let liquidity_pool_id = AccountId::new([7;32]);
let liquidity_pool_cap: u128 = 30u128; let liquidity_pool_cap: u128 = 30u128;
let reserve_a: u128 = 30; let reserve_a: u128 = 30;
let reserve_b: u128 = 20; let reserve_b: u128 = 20;
let user_lp_amt: u128 = 10; let user_lp_amt: u128 = 10;
let token_program_id: [u32;8] = [0; 8]; let token_program_id: [u32;8] = [0; 8];
let vault_a_addr = AccountId::new([7;32]);
let vault_b_addr = AccountId::new([9;32]);
pool.data = PoolDefinition::into_data( PoolDefinition { pool.data = PoolDefinition::into_data( PoolDefinition {
definition_token_a_id: definition_token_a_id.clone(), definition_token_a_id: definition_token_a_id.clone(),
@ -3246,12 +3235,12 @@ mod tests {
let vault_a = AccountWithMetadata { let vault_a = AccountWithMetadata {
account: vault_a.clone(), account: vault_a.clone(),
is_authorized: true, is_authorized: true,
account_id: AccountId::new([2; 32])}; account_id: vault_a_addr.clone()};
let vault_b = AccountWithMetadata { let vault_b = AccountWithMetadata {
account: vault_b.clone(), account: vault_b.clone(),
is_authorized: true, is_authorized: true,
account_id: AccountId::new([3; 32])}; account_id: vault_b_addr.clone()};
let pool_lp = AccountWithMetadata { let pool_lp = AccountWithMetadata {
account: pool_lp.clone(), account: pool_lp.clone(),
@ -3370,6 +3359,8 @@ mod tests {
let reserve_b: u128 = 20; let reserve_b: u128 = 20;
let user_lp_amt: u128 = 10; let user_lp_amt: u128 = 10;
let token_program_id: [u32;8] = [0; 8]; let token_program_id: [u32;8] = [0; 8];
let vault_a_addr = AccountId::new([2;32]);
let vault_b_addr = AccountId::new([3;32]);
pool.data = PoolDefinition::into_data( PoolDefinition { pool.data = PoolDefinition::into_data( PoolDefinition {
definition_token_a_id: definition_token_a_id.clone(), definition_token_a_id: definition_token_a_id.clone(),
@ -3420,12 +3411,12 @@ mod tests {
let vault_a = AccountWithMetadata { let vault_a = AccountWithMetadata {
account: vault_a.clone(), account: vault_a.clone(),
is_authorized: true, is_authorized: true,
account_id: AccountId::new([2; 32])}; account_id: vault_a_addr.clone()};
let vault_b = AccountWithMetadata { let vault_b = AccountWithMetadata {
account: vault_b.clone(), account: vault_b.clone(),
is_authorized: true, is_authorized: true,
account_id: AccountId::new([3; 32])}; account_id: vault_b_addr.clone()};
let pool_lp = AccountWithMetadata { let pool_lp = AccountWithMetadata {
account: pool_lp.clone(), account: pool_lp.clone(),
@ -3674,7 +3665,7 @@ mod tests {
vault1.data = TokenHolding::into_data( vault1.data = TokenHolding::into_data(
TokenHolding { account_type: TOKEN_HOLDING_TYPE, TokenHolding { account_type: TOKEN_HOLDING_TYPE,
definition_id:definition_token_b_id.clone(), definition_id:definition_token_a_id.clone(),
balance: 15u128 } balance: 15u128 }
); );
@ -3746,7 +3737,7 @@ mod tests {
let definition_token_a_id = AccountId::new([1;32]); let definition_token_a_id = AccountId::new([1;32]);
let _definition_token_b_id = AccountId::new([2;32]); let definition_token_b_id = AccountId::new([2;32]);
vault1.data = TokenHolding::into_data( vault1.data = TokenHolding::into_data(
TokenHolding { account_type: TOKEN_HOLDING_TYPE, TokenHolding { account_type: TOKEN_HOLDING_TYPE,
@ -3756,7 +3747,7 @@ mod tests {
vault2.data = TokenHolding::into_data( vault2.data = TokenHolding::into_data(
TokenHolding { account_type: TOKEN_HOLDING_TYPE, TokenHolding { account_type: TOKEN_HOLDING_TYPE,
definition_id:definition_token_a_id.clone(), definition_id:definition_token_b_id.clone(),
balance: 15u128 } balance: 15u128 }
); );
@ -4252,12 +4243,12 @@ mod tests {
is_authorized: true, is_authorized: true,
account_id: AccountId::new([0; 32])}, account_id: AccountId::new([0; 32])},
AccountWithMetadata { AccountWithMetadata {
account: vault_b.clone(),
is_authorized: true,
account_id: vault_b_addr.clone()},
AccountWithMetadata {
account: vault_a.clone(), account: vault_a.clone(),
is_authorized: true, is_authorized: true,
account_id: vault_a_addr.clone()},
AccountWithMetadata {
account: vault_b.clone(),
is_authorized: true,
account_id: vault_b_addr.clone()}, account_id: vault_b_addr.clone()},
AccountWithMetadata { AccountWithMetadata {
account: user_a.clone(), account: user_a.clone(),

View File

@ -2755,7 +2755,7 @@ pub mod tests {
let user_lp_post = state.get_account_by_address(&user_lp_holding_address); let user_lp_post = state.get_account_by_address(&user_lp_holding_address);
//TODO: this accounts for the initial balance for User_LP //TODO: this accounts for the initial balance for User_LP
let delta_lp : u128 = (init_balance_a*init_balance_a + temp_amt)/init_balance_a; let delta_lp : u128 = (init_balance_a*(init_balance_a + temp_amt))/init_balance_a;
let expected_pool = Account { let expected_pool = Account {
program_owner: Program::amm().id(), program_owner: Program::amm().id(),
@ -2767,7 +2767,7 @@ pub mod tests {
vault_a_addr: vault_a_address, vault_a_addr: vault_a_address,
vault_b_addr: vault_b_address, vault_b_addr: vault_b_address,
liquidity_pool_id: token_lp_definition_address, liquidity_pool_id: token_lp_definition_address,
liquidity_pool_cap: init_balance_a - delta_lp, //TODO: not 0 due to temp_amt; results in wrapping arithmetic. liquidity_pool_cap: init_balance_a - delta_lp,
reserve_a: 0, reserve_a: 0,
reserve_b: 0, reserve_b: 0,
token_program_id: Program::token().id(), token_program_id: Program::token().id(),
@ -2840,7 +2840,7 @@ pub mod tests {
assert!(user_a_post == expected_user_a); assert!(user_a_post == expected_user_a);
assert!(user_b_post == expected_user_b); assert!(user_b_post == expected_user_b);
assert!(user_lp_post == expected_user_lp); assert!(user_lp_post == expected_user_lp);
assert!(pool_post == expected_pool); assert!(pool_post.data == expected_pool.data);
} }
#[test] #[test]