diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 858a43c9..b327aaae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,6 +11,10 @@ on: - "**.md" - "!.github/workflows/*.yml" +permissions: + contents: read + pull-requests: read + name: General jobs: @@ -19,7 +23,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - name: Install nightly toolchain for rustfmt run: rustup install nightly --profile minimal --component rustfmt @@ -32,7 +36,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - name: Install taplo-cli run: cargo install --locked taplo-cli @@ -45,7 +49,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - name: Install active toolchain run: rustup install @@ -61,7 +65,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - name: Install cargo-deny run: cargo install --locked cargo-deny @@ -77,7 +81,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - uses: ./.github/actions/install-system-deps @@ -106,7 +110,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - uses: ./.github/actions/install-system-deps @@ -134,7 +138,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - uses: ./.github/actions/install-system-deps @@ -164,7 +168,7 @@ jobs: # steps: # - uses: actions/checkout@v5 # with: - # ref: ${{ github.head_ref }} + # ref: ${{ github.event.pull_request.head.sha || github.head_ref }} # - uses: ./.github/actions/install-system-deps @@ -192,7 +196,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - uses: ./.github/actions/install-system-deps @@ -218,7 +222,7 @@ jobs: steps: - uses: actions/checkout@v5 with: - ref: ${{ github.head_ref }} + ref: ${{ github.event.pull_request.head.sha || github.head_ref }} - uses: ./.github/actions/install-risc0 diff --git a/Cargo.lock b/Cargo.lock index 5f0b29d0..d9c0bc76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8669,6 +8669,7 @@ dependencies = [ "async-stream", "ata_core", "base58", + "bip39", "clap", "common", "env_logger", diff --git a/artifacts/program_methods/amm.bin b/artifacts/program_methods/amm.bin index f71d77d0..c7831f0a 100644 Binary files a/artifacts/program_methods/amm.bin and b/artifacts/program_methods/amm.bin differ diff --git a/artifacts/program_methods/associated_token_account.bin b/artifacts/program_methods/associated_token_account.bin index 71e8e943..133aaa15 100644 Binary files a/artifacts/program_methods/associated_token_account.bin and b/artifacts/program_methods/associated_token_account.bin differ diff --git a/artifacts/program_methods/authenticated_transfer.bin b/artifacts/program_methods/authenticated_transfer.bin index aa64dac0..9b2d4882 100644 Binary files a/artifacts/program_methods/authenticated_transfer.bin and b/artifacts/program_methods/authenticated_transfer.bin differ diff --git a/artifacts/program_methods/clock.bin b/artifacts/program_methods/clock.bin index 60556fdc..89048e9f 100644 Binary files a/artifacts/program_methods/clock.bin and b/artifacts/program_methods/clock.bin differ diff --git a/artifacts/program_methods/pinata.bin b/artifacts/program_methods/pinata.bin index c3d70477..4f29bac8 100644 Binary files a/artifacts/program_methods/pinata.bin and b/artifacts/program_methods/pinata.bin differ diff --git a/artifacts/program_methods/pinata_token.bin b/artifacts/program_methods/pinata_token.bin index 22dbc5e8..4d73bc2f 100644 Binary files a/artifacts/program_methods/pinata_token.bin and b/artifacts/program_methods/pinata_token.bin differ diff --git a/artifacts/program_methods/privacy_preserving_circuit.bin b/artifacts/program_methods/privacy_preserving_circuit.bin index b2c09b79..ad37d810 100644 Binary files a/artifacts/program_methods/privacy_preserving_circuit.bin and b/artifacts/program_methods/privacy_preserving_circuit.bin differ diff --git a/artifacts/program_methods/token.bin b/artifacts/program_methods/token.bin index e24d7a1f..8f94c137 100644 Binary files a/artifacts/program_methods/token.bin and b/artifacts/program_methods/token.bin differ diff --git a/artifacts/test_program_methods/burner.bin b/artifacts/test_program_methods/burner.bin index a740bdb8..23fbcd88 100644 Binary files a/artifacts/test_program_methods/burner.bin and b/artifacts/test_program_methods/burner.bin differ diff --git a/artifacts/test_program_methods/chain_caller.bin b/artifacts/test_program_methods/chain_caller.bin index 112ca113..b871eaf5 100644 Binary files a/artifacts/test_program_methods/chain_caller.bin and b/artifacts/test_program_methods/chain_caller.bin differ diff --git a/artifacts/test_program_methods/changer_claimer.bin b/artifacts/test_program_methods/changer_claimer.bin index a130510b..22ce654e 100644 Binary files a/artifacts/test_program_methods/changer_claimer.bin and b/artifacts/test_program_methods/changer_claimer.bin differ diff --git a/artifacts/test_program_methods/claimer.bin b/artifacts/test_program_methods/claimer.bin index 41a5cb3b..9a99c27f 100644 Binary files a/artifacts/test_program_methods/claimer.bin and b/artifacts/test_program_methods/claimer.bin differ diff --git a/artifacts/test_program_methods/clock_chain_caller.bin b/artifacts/test_program_methods/clock_chain_caller.bin index 57cbb38d..993f9f12 100644 Binary files a/artifacts/test_program_methods/clock_chain_caller.bin and b/artifacts/test_program_methods/clock_chain_caller.bin differ diff --git a/artifacts/test_program_methods/data_changer.bin b/artifacts/test_program_methods/data_changer.bin index 3dddebe1..014cf755 100644 Binary files a/artifacts/test_program_methods/data_changer.bin and b/artifacts/test_program_methods/data_changer.bin differ diff --git a/artifacts/test_program_methods/extra_output.bin b/artifacts/test_program_methods/extra_output.bin index 1d682ec3..06a84868 100644 Binary files a/artifacts/test_program_methods/extra_output.bin and b/artifacts/test_program_methods/extra_output.bin differ diff --git a/artifacts/test_program_methods/malicious_authorization_changer.bin b/artifacts/test_program_methods/malicious_authorization_changer.bin index c68496ab..231cdbf4 100644 Binary files a/artifacts/test_program_methods/malicious_authorization_changer.bin and b/artifacts/test_program_methods/malicious_authorization_changer.bin differ diff --git a/artifacts/test_program_methods/minter.bin b/artifacts/test_program_methods/minter.bin index ffd29461..a08633c1 100644 Binary files a/artifacts/test_program_methods/minter.bin and b/artifacts/test_program_methods/minter.bin differ diff --git a/artifacts/test_program_methods/missing_output.bin b/artifacts/test_program_methods/missing_output.bin index a2bbecd8..1afe8ff1 100644 Binary files a/artifacts/test_program_methods/missing_output.bin and b/artifacts/test_program_methods/missing_output.bin differ diff --git a/artifacts/test_program_methods/modified_transfer.bin b/artifacts/test_program_methods/modified_transfer.bin index b44b1233..86f26d9a 100644 Binary files a/artifacts/test_program_methods/modified_transfer.bin and b/artifacts/test_program_methods/modified_transfer.bin differ diff --git a/artifacts/test_program_methods/nonce_changer.bin b/artifacts/test_program_methods/nonce_changer.bin index e006fc75..b855919f 100644 Binary files a/artifacts/test_program_methods/nonce_changer.bin and b/artifacts/test_program_methods/nonce_changer.bin differ diff --git a/artifacts/test_program_methods/noop.bin b/artifacts/test_program_methods/noop.bin index da811f60..c1518d10 100644 Binary files a/artifacts/test_program_methods/noop.bin and b/artifacts/test_program_methods/noop.bin differ diff --git a/artifacts/test_program_methods/program_owner_changer.bin b/artifacts/test_program_methods/program_owner_changer.bin index 3963873e..a7447878 100644 Binary files a/artifacts/test_program_methods/program_owner_changer.bin and b/artifacts/test_program_methods/program_owner_changer.bin differ diff --git a/artifacts/test_program_methods/simple_balance_transfer.bin b/artifacts/test_program_methods/simple_balance_transfer.bin index 08db47f0..92fce657 100644 Binary files a/artifacts/test_program_methods/simple_balance_transfer.bin and b/artifacts/test_program_methods/simple_balance_transfer.bin differ diff --git a/artifacts/test_program_methods/validity_window.bin b/artifacts/test_program_methods/validity_window.bin index ceb5ae74..3c8e2955 100644 Binary files a/artifacts/test_program_methods/validity_window.bin and b/artifacts/test_program_methods/validity_window.bin differ diff --git a/artifacts/test_program_methods/validity_window_chain_caller.bin b/artifacts/test_program_methods/validity_window_chain_caller.bin index a7661f03..1fdf286e 100644 Binary files a/artifacts/test_program_methods/validity_window_chain_caller.bin and b/artifacts/test_program_methods/validity_window_chain_caller.bin differ diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world.rs b/examples/program_deployment/methods/guest/src/bin/hello_world.rs index 810e83f3..ea2edd95 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world.rs @@ -19,6 +19,7 @@ fn main() { // Read inputs let ( ProgramInput { + self_program_id, pre_states, instruction: greeting, }, @@ -50,5 +51,11 @@ fn main() { // with the NSSA program rules. // WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be // called to commit the output. - ProgramOutput::new(instruction_data, vec![pre_state], vec![post_state]).write(); + ProgramOutput::new( + self_program_id, + instruction_data, + vec![pre_state], + vec![post_state], + ) + .write(); } diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs b/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs index 62908870..3f369fa7 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs @@ -19,6 +19,7 @@ fn main() { // Read inputs let ( ProgramInput { + self_program_id, pre_states, instruction: greeting, }, @@ -57,5 +58,11 @@ fn main() { // with the NSSA program rules. // WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be // called to commit the output. - ProgramOutput::new(instruction_data, vec![pre_state], vec![post_state]).write(); + ProgramOutput::new( + self_program_id, + instruction_data, + vec![pre_state], + vec![post_state], + ) + .write(); } diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs b/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs index 7e29b5de..57a2190c 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs @@ -66,6 +66,7 @@ fn main() { // Read input accounts. let ( ProgramInput { + self_program_id, pre_states, instruction: (function_id, data), }, @@ -85,5 +86,5 @@ fn main() { // WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be // called to commit the output. - ProgramOutput::new(instruction_words, pre_states, post_states).write(); + ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write(); } diff --git a/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs b/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs index d2c04083..22098b7a 100644 --- a/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs +++ b/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs @@ -27,6 +27,7 @@ fn main() { // Read inputs let ( ProgramInput { + self_program_id, pre_states, instruction: (), }, @@ -55,7 +56,12 @@ fn main() { // Write the outputs. // WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be // called to commit the output. - ProgramOutput::new(instruction_data, vec![pre_state], vec![post_state]) - .with_chained_calls(vec![chained_call]) - .write(); + ProgramOutput::new( + self_program_id, + instruction_data, + vec![pre_state], + vec![post_state], + ) + .with_chained_calls(vec![chained_call]) + .write(); } diff --git a/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs b/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs index 564efc2b..2ae65ec7 100644 --- a/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs +++ b/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs @@ -33,6 +33,7 @@ fn main() { // Read inputs let ( ProgramInput { + self_program_id, pre_states, instruction: (), }, @@ -68,7 +69,12 @@ fn main() { // Write the outputs. // WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be // called to commit the output. - ProgramOutput::new(instruction_data, vec![pre_state], vec![post_state]) - .with_chained_calls(vec![chained_call]) - .write(); + ProgramOutput::new( + self_program_id, + instruction_data, + vec![pre_state], + vec![post_state], + ) + .with_chained_calls(vec![chained_call]) + .write(); } diff --git a/integration_tests/src/lib.rs b/integration_tests/src/lib.rs index 08e7cf9f..a4381acf 100644 --- a/integration_tests/src/lib.rs +++ b/integration_tests/src/lib.rs @@ -256,11 +256,11 @@ impl TestContext { let config_overrides = WalletConfigOverrides::default(); let wallet_password = "test_pass".to_owned(); - let wallet = WalletCore::new_init_storage( + let (wallet, _mnemonic) = WalletCore::new_init_storage( config_path, storage_path, Some(config_overrides), - wallet_password.clone(), + &wallet_password, ) .context("Failed to init wallet")?; wallet diff --git a/integration_tests/tests/amm.rs b/integration_tests/tests/amm.rs index 42aa5f3f..d9ecb831 100644 --- a/integration_tests/tests/amm.rs +++ b/integration_tests/tests/amm.rs @@ -223,7 +223,7 @@ async fn amm_public() -> Result<()> { // Make swap - let subcommand = AmmProgramAgnosticSubcommand::Swap { + let subcommand = AmmProgramAgnosticSubcommand::SwapExactInput { user_holding_a: format_public_account_id(recipient_account_id_1), user_holding_b: format_public_account_id(recipient_account_id_2), amount_in: 2, @@ -266,7 +266,7 @@ async fn amm_public() -> Result<()> { // Make swap - let subcommand = AmmProgramAgnosticSubcommand::Swap { + let subcommand = AmmProgramAgnosticSubcommand::SwapExactInput { user_holding_a: format_public_account_id(recipient_account_id_1), user_holding_b: format_public_account_id(recipient_account_id_2), amount_in: 2, diff --git a/integration_tests/tests/wallet_ffi.rs b/integration_tests/tests/wallet_ffi.rs index 6e6b190c..ac548280 100644 --- a/integration_tests/tests/wallet_ffi.rs +++ b/integration_tests/tests/wallet_ffi.rs @@ -24,7 +24,6 @@ use log::info; use nssa::{Account, AccountId, PrivateKey, PublicKey, program::Program}; use nssa_core::program::DEFAULT_PROGRAM_ID; use tempfile::tempdir; -use wallet::WalletCore; use wallet_ffi::{ FfiAccount, FfiAccountList, FfiBytes32, FfiPrivateAccountKeys, FfiPublicAccountKey, FfiTransferResult, WalletHandle, error, @@ -211,14 +210,6 @@ fn new_wallet_ffi_with_default_config(password: &str) -> Result<*mut WalletHandl }) } -fn new_wallet_rust_with_default_config(password: &str) -> Result { - let tempdir = tempdir()?; - let config_path = tempdir.path().join("wallet_config.json"); - let storage_path = tempdir.path().join("storage.json"); - - WalletCore::new_init_storage(config_path, storage_path, None, password.to_owned()) -} - fn load_existing_ffi_wallet(home: &Path) -> Result<*mut WalletHandle> { let config_path = home.join("wallet_config.json"); let storage_path = home.join("storage.json"); @@ -232,19 +223,8 @@ fn load_existing_ffi_wallet(home: &Path) -> Result<*mut WalletHandle> { fn wallet_ffi_create_public_accounts() -> Result<()> { let password = "password_for_tests"; let n_accounts = 10; - // First `n_accounts` public accounts created with Rust wallet - let new_public_account_ids_rust = { - let mut account_ids = Vec::new(); - let mut wallet_rust = new_wallet_rust_with_default_config(password)?; - for _ in 0..n_accounts { - let account_id = wallet_rust.create_new_account_public(None).0; - account_ids.push(*account_id.value()); - } - account_ids - }; - - // First `n_accounts` public accounts created with wallet FFI + // Create `n_accounts` public accounts with wallet FFI let new_public_account_ids_ffi = unsafe { let mut account_ids = Vec::new(); @@ -258,7 +238,20 @@ fn wallet_ffi_create_public_accounts() -> Result<()> { account_ids }; - assert_eq!(new_public_account_ids_ffi, new_public_account_ids_rust); + // All returned IDs must be unique and non-zero + assert_eq!(new_public_account_ids_ffi.len(), n_accounts); + let unique: HashSet<_> = new_public_account_ids_ffi.iter().collect(); + assert_eq!( + unique.len(), + n_accounts, + "Duplicate public account IDs returned" + ); + assert!( + new_public_account_ids_ffi + .iter() + .all(|id| *id != [0_u8; 32]), + "Zero account ID returned" + ); Ok(()) } @@ -267,19 +260,7 @@ fn wallet_ffi_create_public_accounts() -> Result<()> { fn wallet_ffi_create_private_accounts() -> Result<()> { let password = "password_for_tests"; let n_accounts = 10; - // First `n_accounts` private accounts created with Rust wallet - let new_private_account_ids_rust = { - let mut account_ids = Vec::new(); - - let mut wallet_rust = new_wallet_rust_with_default_config(password)?; - for _ in 0..n_accounts { - let account_id = wallet_rust.create_new_account_private(None).0; - account_ids.push(*account_id.value()); - } - account_ids - }; - - // First `n_accounts` private accounts created with wallet FFI + // Create `n_accounts` private accounts with wallet FFI let new_private_account_ids_ffi = unsafe { let mut account_ids = Vec::new(); @@ -293,7 +274,20 @@ fn wallet_ffi_create_private_accounts() -> Result<()> { account_ids }; - assert_eq!(new_private_account_ids_ffi, new_private_account_ids_rust); + // All returned IDs must be unique and non-zero + assert_eq!(new_private_account_ids_ffi.len(), n_accounts); + let unique: HashSet<_> = new_private_account_ids_ffi.iter().collect(); + assert_eq!( + unique.len(), + n_accounts, + "Duplicate private account IDs returned" + ); + assert!( + new_private_account_ids_ffi + .iter() + .all(|id| *id != [0_u8; 32]), + "Zero account ID returned" + ); Ok(()) } @@ -349,28 +343,23 @@ fn wallet_ffi_save_and_load_persistent_storage() -> Result<()> { fn test_wallet_ffi_list_accounts() -> Result<()> { let password = "password_for_tests"; - // Create the wallet FFI - let wallet_ffi_handle = unsafe { + // Create the wallet FFI and track which account IDs were created as public/private + let (wallet_ffi_handle, created_public_ids, created_private_ids) = unsafe { let handle = new_wallet_ffi_with_default_config(password)?; - // Create 5 public accounts and 5 private accounts + let mut public_ids: Vec<[u8; 32]> = Vec::new(); + let mut private_ids: Vec<[u8; 32]> = Vec::new(); + + // Create 5 public accounts and 5 private accounts, recording their IDs for _ in 0..5 { let mut out_account_id = FfiBytes32::from_bytes([0; 32]); wallet_ffi_create_account_public(handle, &raw mut out_account_id); + public_ids.push(out_account_id.data); + wallet_ffi_create_account_private(handle, &raw mut out_account_id); + private_ids.push(out_account_id.data); } - handle - }; - - // Create the wallet Rust - let wallet_rust = { - let mut wallet = new_wallet_rust_with_default_config(password)?; - // Create 5 public accounts and 5 private accounts - for _ in 0..5 { - wallet.create_new_account_public(None); - wallet.create_new_account_private(None); - } - wallet + (handle, public_ids, private_ids) }; // Get the account list with FFI method @@ -380,15 +369,6 @@ fn test_wallet_ffi_list_accounts() -> Result<()> { out_list }; - let wallet_rust_account_ids = wallet_rust - .storage() - .user_data - .account_ids() - .collect::>(); - - // Assert same number of elements between Rust and FFI result - assert_eq!(wallet_rust_account_ids.len(), wallet_ffi_account_list.count); - let wallet_ffi_account_list_slice = unsafe { core::slice::from_raw_parts( wallet_ffi_account_list.entries, @@ -396,37 +376,38 @@ fn test_wallet_ffi_list_accounts() -> Result<()> { ) }; - // Assert same account ids between Rust and FFI result - assert_eq!( - wallet_rust_account_ids - .iter() - .map(nssa::AccountId::value) - .collect::>(), - wallet_ffi_account_list_slice - .iter() - .map(|entry| &entry.account_id.data) - .collect::>() - ); + // All created accounts must appear in the list + let listed_public_ids: HashSet<[u8; 32]> = wallet_ffi_account_list_slice + .iter() + .filter(|e| e.is_public) + .map(|e| e.account_id.data) + .collect(); + let listed_private_ids: HashSet<[u8; 32]> = wallet_ffi_account_list_slice + .iter() + .filter(|e| !e.is_public) + .map(|e| e.account_id.data) + .collect(); - // Assert `is_pub` flag is correct in the FFI result - for entry in wallet_ffi_account_list_slice { - let account_id = AccountId::new(entry.account_id.data); - let is_pub_default_in_rust_wallet = wallet_rust - .storage() - .user_data - .default_pub_account_signing_keys - .contains_key(&account_id); - let is_pub_key_tree_wallet_rust = wallet_rust - .storage() - .user_data - .public_key_tree - .account_id_map - .contains_key(&account_id); - - let is_public_in_rust_wallet = is_pub_default_in_rust_wallet || is_pub_key_tree_wallet_rust; - - assert_eq!(entry.is_public, is_public_in_rust_wallet); + for id in &created_public_ids { + assert!( + listed_public_ids.contains(id), + "Created public account not found in list with is_public=true" + ); } + for id in &created_private_ids { + assert!( + listed_private_ids.contains(id), + "Created private account not found in list with is_public=false" + ); + } + + // Total listed accounts must be at least the number we created + assert!( + wallet_ffi_account_list.count >= created_public_ids.len() + created_private_ids.len(), + "Listed account count ({}) is less than the number of created accounts ({})", + wallet_ffi_account_list.count, + created_public_ids.len() + created_private_ids.len() + ); unsafe { wallet_ffi_free_account_list(&raw mut wallet_ffi_account_list); diff --git a/key_protocol/src/key_management/mod.rs b/key_protocol/src/key_management/mod.rs index dcdaff45..c038c415 100644 --- a/key_protocol/src/key_management/mod.rs +++ b/key_protocol/src/key_management/mod.rs @@ -42,10 +42,10 @@ impl KeyChain { } #[must_use] - pub fn new_mnemonic(passphrase: String) -> Self { + pub fn new_mnemonic(passphrase: &str) -> (Self, bip39::Mnemonic) { // Currently dropping SeedHolder at the end of initialization. // Not entirely sure if we need it in the future. - let seed_holder = SeedHolder::new_mnemonic(passphrase); + let (seed_holder, mnemonic) = SeedHolder::new_mnemonic(passphrase); let secret_spending_key = seed_holder.produce_top_secret_key_holder(); let private_key_holder = secret_spending_key.produce_private_key_holder(None); @@ -53,12 +53,15 @@ impl KeyChain { let nullifier_public_key = private_key_holder.generate_nullifier_public_key(); let viewing_public_key = private_key_holder.generate_viewing_public_key(); - Self { - secret_spending_key, - private_key_holder, - nullifier_public_key, - viewing_public_key, - } + ( + Self { + secret_spending_key, + private_key_holder, + nullifier_public_key, + viewing_public_key, + }, + mnemonic, + ) } #[must_use] diff --git a/key_protocol/src/key_management/secret_holders.rs b/key_protocol/src/key_management/secret_holders.rs index 02890631..9804ba39 100644 --- a/key_protocol/src/key_management/secret_holders.rs +++ b/key_protocol/src/key_management/secret_holders.rs @@ -8,8 +8,6 @@ use rand::{RngCore as _, rngs::OsRng}; use serde::{Deserialize, Serialize}; use sha2::{Digest as _, digest::FixedOutput as _}; -const NSSA_ENTROPY_BYTES: [u8; 32] = [0; 32]; - /// Seed holder. Non-clonable to ensure that different holders use different seeds. /// Produces `TopSecretKeyHolder` objects. #[derive(Debug)] @@ -48,9 +46,24 @@ impl SeedHolder { } #[must_use] - pub fn new_mnemonic(passphrase: String) -> Self { - let mnemonic = Mnemonic::from_entropy(&NSSA_ENTROPY_BYTES) - .expect("Enthropy must be a multiple of 32 bytes"); + pub fn new_mnemonic(passphrase: &str) -> (Self, Mnemonic) { + let mut entropy_bytes: [u8; 32] = [0; 32]; + OsRng.fill_bytes(&mut entropy_bytes); + + let mnemonic = + Mnemonic::from_entropy(&entropy_bytes).expect("Entropy must be a multiple of 32 bytes"); + let seed_wide = mnemonic.to_seed(passphrase); + + ( + Self { + seed: seed_wide.to_vec(), + }, + mnemonic, + ) + } + + #[must_use] + pub fn from_mnemonic(mnemonic: &Mnemonic, passphrase: &str) -> Self { let seed_wide = mnemonic.to_seed(passphrase); Self { @@ -175,12 +188,63 @@ mod tests { } #[test] - fn two_seeds_generated_same_from_same_mnemonic() { - let mnemonic = "test_pass"; + fn two_seeds_recovered_same_from_same_mnemonic() { + let passphrase = "test_pass"; - let seed_holder1 = SeedHolder::new_mnemonic(mnemonic.to_owned()); - let seed_holder2 = SeedHolder::new_mnemonic(mnemonic.to_owned()); + // Generate a mnemonic with random entropy + let (original_seed_holder, mnemonic) = SeedHolder::new_mnemonic(passphrase); - assert_eq!(seed_holder1.seed, seed_holder2.seed); + // Recover from the same mnemonic + let recovered_seed_holder = SeedHolder::from_mnemonic(&mnemonic, passphrase); + + assert_eq!(original_seed_holder.seed, recovered_seed_holder.seed); + } + + #[test] + fn new_mnemonic_generates_different_seeds_each_time() { + let (seed_holder1, mnemonic1) = SeedHolder::new_mnemonic(""); + let (seed_holder2, mnemonic2) = SeedHolder::new_mnemonic(""); + + // Different entropy should produce different mnemonics and seeds + assert_ne!(mnemonic1.to_string(), mnemonic2.to_string()); + assert_ne!(seed_holder1.seed, seed_holder2.seed); + } + + #[test] + fn new_mnemonic_generates_24_word_phrase() { + let (_seed_holder, mnemonic) = SeedHolder::new_mnemonic(""); + + // 256 bits of entropy produces a 24-word mnemonic + let word_count = mnemonic.to_string().split_whitespace().count(); + assert_eq!(word_count, 24); + } + + #[test] + fn new_mnemonic_produces_valid_seed_length() { + let (seed_holder, _mnemonic) = SeedHolder::new_mnemonic(""); + + assert_eq!(seed_holder.seed.len(), 64); + } + + #[test] + fn different_passphrases_produce_different_seeds() { + let (_seed_holder, mnemonic) = SeedHolder::new_mnemonic(""); + + let seed_with_pass_a = SeedHolder::from_mnemonic(&mnemonic, "password_a"); + let seed_with_pass_b = SeedHolder::from_mnemonic(&mnemonic, "password_b"); + + // Same mnemonic but different passphrases should produce different seeds + assert_ne!(seed_with_pass_a.seed, seed_with_pass_b.seed); + } + + #[test] + fn empty_passphrase_is_deterministic() { + let (_seed_holder, mnemonic) = SeedHolder::new_mnemonic(""); + + let seed1 = SeedHolder::from_mnemonic(&mnemonic, ""); + let seed2 = SeedHolder::from_mnemonic(&mnemonic, ""); + + // Same mnemonic and passphrase should always produce the same seed + assert_eq!(seed1.seed, seed2.seed); } } diff --git a/key_protocol/src/key_protocol_core/mod.rs b/key_protocol/src/key_protocol_core/mod.rs index 8232d9f4..8186865f 100644 --- a/key_protocol/src/key_protocol_core/mod.rs +++ b/key_protocol/src/key_protocol_core/mod.rs @@ -181,11 +181,12 @@ impl NSSAUserData { impl Default for NSSAUserData { fn default() -> Self { + let (seed_holder, _mnemonic) = SeedHolder::new_mnemonic(""); Self::new_with_accounts( BTreeMap::new(), BTreeMap::new(), - KeyTreePublic::new(&SeedHolder::new_mnemonic("default".to_owned())), - KeyTreePrivate::new(&SeedHolder::new_mnemonic("default".to_owned())), + KeyTreePublic::new(&seed_holder), + KeyTreePrivate::new(&seed_holder), ) .unwrap() } diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index 673e09b3..057c8238 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -16,6 +16,7 @@ pub const MAX_NUMBER_CHAINED_CALLS: usize = 10; pub type ProgramId = [u32; 8]; pub type InstructionData = Vec; pub struct ProgramInput { + pub self_program_id: ProgramId, pub pre_states: Vec, pub instruction: T, } @@ -281,6 +282,8 @@ pub struct InvalidWindow; #[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))] #[must_use = "ProgramOutput does nothing unless written"] pub struct ProgramOutput { + /// The program ID of the program that produced this output. + pub self_program_id: ProgramId, /// The instruction data the program received to produce this output. pub instruction_data: InstructionData, /// The account pre states the program received to produce this output. @@ -297,11 +300,13 @@ pub struct ProgramOutput { impl ProgramOutput { pub const fn new( + self_program_id: ProgramId, instruction_data: InstructionData, pre_states: Vec, post_states: Vec, ) -> Self { Self { + self_program_id, instruction_data, pre_states, post_states, @@ -415,11 +420,13 @@ pub fn compute_authorized_pdas( /// Reads the NSSA inputs from the guest environment. #[must_use] pub fn read_nssa_inputs() -> (ProgramInput, InstructionData) { + let self_program_id: ProgramId = env::read(); let pre_states: Vec = env::read(); let instruction_words: InstructionData = env::read(); let instruction = T::deserialize(&mut Deserializer::new(instruction_words.as_ref())).unwrap(); ( ProgramInput { + self_program_id, pre_states, instruction, }, @@ -620,7 +627,7 @@ mod tests { #[test] fn program_output_try_with_block_validity_window_range() { - let output = ProgramOutput::new(vec![], vec![], vec![]) + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) .try_with_block_validity_window(10_u64..100) .unwrap(); assert_eq!(output.block_validity_window.start(), Some(10)); @@ -629,24 +636,24 @@ mod tests { #[test] fn program_output_with_block_validity_window_range_from() { - let output = - ProgramOutput::new(vec![], vec![], vec![]).with_block_validity_window(10_u64..); + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + .with_block_validity_window(10_u64..); assert_eq!(output.block_validity_window.start(), Some(10)); assert_eq!(output.block_validity_window.end(), None); } #[test] fn program_output_with_block_validity_window_range_to() { - let output = - ProgramOutput::new(vec![], vec![], vec![]).with_block_validity_window(..100_u64); + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + .with_block_validity_window(..100_u64); assert_eq!(output.block_validity_window.start(), None); assert_eq!(output.block_validity_window.end(), Some(100)); } #[test] fn program_output_try_with_block_validity_window_empty_range_fails() { - let result = - ProgramOutput::new(vec![], vec![], vec![]).try_with_block_validity_window(5_u64..5); + let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + .try_with_block_validity_window(5_u64..5); assert!(result.is_err()); } diff --git a/nssa/src/privacy_preserving_transaction/circuit.rs b/nssa/src/privacy_preserving_transaction/circuit.rs index 0ae7eaac..48c59ce7 100644 --- a/nssa/src/privacy_preserving_transaction/circuit.rs +++ b/nssa/src/privacy_preserving_transaction/circuit.rs @@ -158,7 +158,7 @@ fn execute_and_prove_program( ) -> Result { // Write inputs to the program let mut env_builder = ExecutorEnv::builder(); - Program::write_inputs(pre_states, instruction_data, &mut env_builder)?; + Program::write_inputs(program.id(), pre_states, instruction_data, &mut env_builder)?; let env = env_builder.build().unwrap(); // Prove the program diff --git a/nssa/src/program.rs b/nssa/src/program.rs index b8fb2595..a7e376ee 100644 --- a/nssa/src/program.rs +++ b/nssa/src/program.rs @@ -59,7 +59,7 @@ impl Program { // Write inputs to the program let mut env_builder = ExecutorEnv::builder(); env_builder.session_limit(Some(MAX_NUM_CYCLES_PUBLIC_EXECUTION)); - Self::write_inputs(pre_states, instruction_data, &mut env_builder)?; + Self::write_inputs(self.id, pre_states, instruction_data, &mut env_builder)?; let env = env_builder.build().unwrap(); // Execute the program (without proving) @@ -79,13 +79,20 @@ impl Program { /// Writes inputs to `env_builder` in the order expected by the programs. pub(crate) fn write_inputs( + program_id: ProgramId, pre_states: &[AccountWithMetadata], instruction_data: &[u32], env_builder: &mut ExecutorEnvBuilder, ) -> Result<(), NssaError> { + env_builder + .write(&program_id) + .map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?; let pre_states = pre_states.to_vec(); env_builder - .write(&(pre_states, instruction_data)) + .write(&pre_states) + .map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?; + env_builder + .write(&instruction_data) .map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?; Ok(()) } diff --git a/nssa/src/validated_state_diff.rs b/nssa/src/validated_state_diff.rs index 6598f711..e4e0cacc 100644 --- a/nssa/src/validated_state_diff.rs +++ b/nssa/src/validated_state_diff.rs @@ -151,6 +151,12 @@ impl ValidatedStateDiff { ); } + // Verify that the program output's self_program_id matches the expected program ID. + ensure!( + program_output.self_program_id == chained_call.program_id, + NssaError::InvalidProgramBehavior + ); + // Verify execution corresponds to a well-behaved program. // See the # Programs section for the definition of the `validate_execution` method. ensure!( diff --git a/program_methods/guest/src/bin/amm.rs b/program_methods/guest/src/bin/amm.rs index 748630d9..59c89742 100644 --- a/program_methods/guest/src/bin/amm.rs +++ b/program_methods/guest/src/bin/amm.rs @@ -14,6 +14,7 @@ use nssa_core::program::{ProgramInput, ProgramOutput, read_nssa_inputs}; fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction, }, @@ -112,15 +113,15 @@ fn main() { min_amount_to_remove_token_b, ) } - Instruction::Swap { + Instruction::SwapExactInput { swap_amount_in, min_amount_out, token_definition_id_in, } => { let [pool, vault_a, vault_b, user_holding_a, user_holding_b] = pre_states .try_into() - .expect("Transfer instruction requires exactly five accounts"); - amm_program::swap::swap( + .expect("SwapExactInput instruction requires exactly five accounts"); + amm_program::swap::swap_exact_input( pool, vault_a, vault_b, @@ -131,9 +132,33 @@ fn main() { token_definition_id_in, ) } + Instruction::SwapExactOutput { + exact_amount_out, + max_amount_in, + token_definition_id_in, + } => { + let [pool, vault_a, vault_b, user_holding_a, user_holding_b] = pre_states + .try_into() + .expect("SwapExactOutput instruction requires exactly five accounts"); + amm_program::swap::swap_exact_output( + pool, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + exact_amount_out, + max_amount_in, + token_definition_id_in, + ) + } }; - ProgramOutput::new(instruction_words, pre_states_clone, post_states) - .with_chained_calls(chained_calls) - .write(); + ProgramOutput::new( + self_program_id, + instruction_words, + pre_states_clone, + post_states, + ) + .with_chained_calls(chained_calls) + .write(); } diff --git a/program_methods/guest/src/bin/associated_token_account.rs b/program_methods/guest/src/bin/associated_token_account.rs index 55d5824b..42162ba2 100644 --- a/program_methods/guest/src/bin/associated_token_account.rs +++ b/program_methods/guest/src/bin/associated_token_account.rs @@ -4,6 +4,7 @@ use nssa_core::program::{ProgramInput, ProgramOutput, read_nssa_inputs}; fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction, }, @@ -56,7 +57,12 @@ fn main() { } }; - ProgramOutput::new(instruction_words, pre_states_clone, post_states) - .with_chained_calls(chained_calls) - .write(); + ProgramOutput::new( + self_program_id, + instruction_words, + pre_states_clone, + post_states, + ) + .with_chained_calls(chained_calls) + .write(); } diff --git a/program_methods/guest/src/bin/authenticated_transfer.rs b/program_methods/guest/src/bin/authenticated_transfer.rs index 2fb0ea8b..d7c68e62 100644 --- a/program_methods/guest/src/bin/authenticated_transfer.rs +++ b/program_methods/guest/src/bin/authenticated_transfer.rs @@ -67,6 +67,7 @@ fn main() { // Read input accounts. let ( ProgramInput { + self_program_id, pre_states, instruction: balance_to_move, }, @@ -84,5 +85,5 @@ fn main() { _ => panic!("invalid params"), }; - ProgramOutput::new(instruction_words, pre_states, post_states).write(); + ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write(); } diff --git a/program_methods/guest/src/bin/clock.rs b/program_methods/guest/src/bin/clock.rs index 9e15cc8b..4cdc86dc 100644 --- a/program_methods/guest/src/bin/clock.rs +++ b/program_methods/guest/src/bin/clock.rs @@ -29,6 +29,7 @@ fn update_if_multiple( fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: timestamp, }, @@ -68,6 +69,7 @@ fn main() { let (pre_50, post_50) = update_if_multiple(pre_50, 50, current_block_id, updated_data); ProgramOutput::new( + self_program_id, instruction_words, vec![pre_01, pre_10, pre_50], vec![post_01, post_10, post_50], diff --git a/program_methods/guest/src/bin/pinata.rs b/program_methods/guest/src/bin/pinata.rs index 2f85f069..d6f35ae8 100644 --- a/program_methods/guest/src/bin/pinata.rs +++ b/program_methods/guest/src/bin/pinata.rs @@ -46,6 +46,7 @@ fn main() { // It is expected to receive only two accounts: [pinata_account, winner_account] let ( ProgramInput { + self_program_id, pre_states, instruction: solution, }, @@ -79,6 +80,7 @@ fn main() { .expect("Overflow when adding prize to winner"); ProgramOutput::new( + self_program_id, instruction_words, vec![pinata, winner], vec![ diff --git a/program_methods/guest/src/bin/pinata_token.rs b/program_methods/guest/src/bin/pinata_token.rs index 3dee05b7..5c31af45 100644 --- a/program_methods/guest/src/bin/pinata_token.rs +++ b/program_methods/guest/src/bin/pinata_token.rs @@ -52,6 +52,7 @@ fn main() { // winner_token_holding] let ( ProgramInput { + self_program_id, pre_states, instruction: solution, }, @@ -97,6 +98,7 @@ fn main() { .with_pda_seeds(vec![PdaSeed::new([0; 32])]); ProgramOutput::new( + self_program_id, instruction_words, vec![ pinata_definition, diff --git a/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/program_methods/guest/src/bin/privacy_preserving_circuit.rs index e53334f9..48d4b3b7 100644 --- a/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -107,6 +107,13 @@ impl ExecutionState { |_: Infallible| unreachable!("Infallible error is never constructed"), ); + // Verify that the program output's self_program_id matches the expected program ID. + // This ensures the proof commits to which program produced the output. + assert_eq!( + program_output.self_program_id, chained_call.program_id, + "Program output self_program_id does not match chained call program_id" + ); + // Check that the program is well behaved. // See the # Programs section for the definition of the `validate_execution` method. let execution_valid = validate_execution( diff --git a/program_methods/guest/src/bin/token.rs b/program_methods/guest/src/bin/token.rs index 421d43ef..2414a289 100644 --- a/program_methods/guest/src/bin/token.rs +++ b/program_methods/guest/src/bin/token.rs @@ -12,6 +12,7 @@ use token_program::core::Instruction; fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction, }, @@ -81,5 +82,11 @@ fn main() { } }; - ProgramOutput::new(instruction_words, pre_states_clone, post_states).write(); + ProgramOutput::new( + self_program_id, + instruction_words, + pre_states_clone, + post_states, + ) + .write(); } diff --git a/programs/amm/core/src/lib.rs b/programs/amm/core/src/lib.rs index 85efd00d..017f14ff 100644 --- a/programs/amm/core/src/lib.rs +++ b/programs/amm/core/src/lib.rs @@ -68,11 +68,27 @@ pub enum Instruction { /// - User Holding Account for Token A /// - User Holding Account for Token B Either User Holding Account for Token A or Token B is /// authorized. - Swap { + SwapExactInput { swap_amount_in: u128, min_amount_out: u128, token_definition_id_in: AccountId, }, + + /// Swap tokens specifying the exact desired output amount, + /// while maintaining the Pool constant product. + /// + /// Required accounts: + /// - AMM Pool (initialized) + /// - Vault Holding Account for Token A (initialized) + /// - Vault Holding Account for Token B (initialized) + /// - User Holding Account for Token A + /// - User Holding Account for Token B Either User Holding Account for Token A or Token B is + /// authorized. + SwapExactOutput { + exact_amount_out: u128, + max_amount_in: u128, + token_definition_id_in: AccountId, + }, } #[derive(Clone, Default, Serialize, Deserialize, BorshSerialize, BorshDeserialize)] diff --git a/programs/amm/src/swap.rs b/programs/amm/src/swap.rs index cb64f5eb..22f3792a 100644 --- a/programs/amm/src/swap.rs +++ b/programs/amm/src/swap.rs @@ -4,21 +4,14 @@ use nssa_core::{ program::{AccountPostState, ChainedCall}, }; -#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] -#[must_use] -pub fn swap( - pool: AccountWithMetadata, - vault_a: AccountWithMetadata, - vault_b: AccountWithMetadata, - user_holding_a: AccountWithMetadata, - user_holding_b: AccountWithMetadata, - swap_amount_in: u128, - min_amount_out: u128, - token_in_id: AccountId, -) -> (Vec, Vec) { - // Verify vaults are in fact vaults +/// Validates swap setup: checks pool is active, vaults match, and reserves are sufficient. +fn validate_swap_setup( + pool: &AccountWithMetadata, + vault_a: &AccountWithMetadata, + vault_b: &AccountWithMetadata, +) -> PoolDefinition { let pool_def_data = PoolDefinition::try_from(&pool.account.data) - .expect("Swap: AMM Program expects a valid Pool Definition Account"); + .expect("AMM Program expects a valid Pool Definition Account"); assert!(pool_def_data.active, "Pool is inactive"); assert_eq!( @@ -30,16 +23,14 @@ pub fn swap( "Vault B was not provided" ); - // fetch pool reserves - // validates reserves is at least the vaults' balances let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data) - .expect("Swap: AMM Program expects a valid Token Holding Account for Vault A"); + .expect("AMM Program expects a valid Token Holding Account for Vault A"); let token_core::TokenHolding::Fungible { definition_id: _, balance: vault_a_balance, } = vault_a_token_holding else { - panic!("Swap: AMM Program expects a valid Fungible Token Holding Account for Vault A"); + panic!("AMM Program expects a valid Fungible Token Holding Account for Vault A"); }; assert!( @@ -48,13 +39,13 @@ pub fn swap( ); let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data) - .expect("Swap: AMM Program expects a valid Token Holding Account for Vault B"); + .expect("AMM Program expects a valid Token Holding Account for Vault B"); let token_core::TokenHolding::Fungible { definition_id: _, balance: vault_b_balance, } = vault_b_token_holding else { - panic!("Swap: AMM Program expects a valid Fungible Token Holding Account for Vault B"); + panic!("AMM Program expects a valid Fungible Token Holding Account for Vault B"); }; assert!( @@ -62,6 +53,59 @@ pub fn swap( "Reserve for Token B exceeds vault balance" ); + pool_def_data +} + +/// Creates post-state and returns reserves after swap. +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +#[expect( + clippy::needless_pass_by_value, + reason = "consistent with codebase style" +)] +fn create_swap_post_states( + pool: AccountWithMetadata, + pool_def_data: PoolDefinition, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + deposit_a: u128, + withdraw_a: u128, + deposit_b: u128, + withdraw_b: u128, +) -> Vec { + let mut pool_post = pool.account; + let pool_post_definition = PoolDefinition { + reserve_a: pool_def_data.reserve_a + deposit_a - withdraw_a, + reserve_b: pool_def_data.reserve_b + deposit_b - withdraw_b, + ..pool_def_data + }; + + pool_post.data = Data::from(&pool_post_definition); + + vec![ + AccountPostState::new(pool_post), + AccountPostState::new(vault_a.account), + AccountPostState::new(vault_b.account), + AccountPostState::new(user_holding_a.account), + AccountPostState::new(user_holding_b.account), + ] +} + +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +#[must_use] +pub fn swap_exact_input( + pool: AccountWithMetadata, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + swap_amount_in: u128, + min_amount_out: u128, + token_in_id: AccountId, +) -> (Vec, Vec) { + let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b); + let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) = if token_in_id == pool_def_data.definition_token_a_id { let (chained_calls, deposit_a, withdraw_b) = swap_logic( @@ -95,23 +139,18 @@ pub fn swap( panic!("AccountId is not a token type for the pool"); }; - // Update pool account - let mut pool_post = pool.account; - let pool_post_definition = PoolDefinition { - reserve_a: pool_def_data.reserve_a + deposit_a - withdraw_a, - reserve_b: pool_def_data.reserve_b + deposit_b - withdraw_b, - ..pool_def_data - }; - - pool_post.data = Data::from(&pool_post_definition); - - let post_states = vec![ - AccountPostState::new(pool_post), - AccountPostState::new(vault_a.account), - AccountPostState::new(vault_b.account), - AccountPostState::new(user_holding_a.account), - AccountPostState::new(user_holding_b.account), - ]; + let post_states = create_swap_post_states( + pool, + pool_def_data, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + deposit_a, + withdraw_a, + deposit_b, + withdraw_b, + ); (post_states, chained_calls) } @@ -131,7 +170,9 @@ fn swap_logic( // Compute withdraw amount // Maintains pool constant product // k = pool_def_data.reserve_a * pool_def_data.reserve_b; - let withdraw_amount = (reserve_withdraw_vault_amount * swap_amount_in) + let withdraw_amount = reserve_withdraw_vault_amount + .checked_mul(swap_amount_in) + .expect("reserve * amount_in overflows u128") / (reserve_deposit_vault_amount + swap_amount_in); // Slippage check @@ -175,3 +216,135 @@ fn swap_logic( (chained_calls, swap_amount_in, withdraw_amount) } + +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +#[must_use] +pub fn swap_exact_output( + pool: AccountWithMetadata, + vault_a: AccountWithMetadata, + vault_b: AccountWithMetadata, + user_holding_a: AccountWithMetadata, + user_holding_b: AccountWithMetadata, + exact_amount_out: u128, + max_amount_in: u128, + token_in_id: AccountId, +) -> (Vec, Vec) { + let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b); + + let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) = + if token_in_id == pool_def_data.definition_token_a_id { + let (chained_calls, deposit_a, withdraw_b) = exact_output_swap_logic( + user_holding_a.clone(), + vault_a.clone(), + vault_b.clone(), + user_holding_b.clone(), + exact_amount_out, + max_amount_in, + pool_def_data.reserve_a, + pool_def_data.reserve_b, + pool.account_id, + ); + + (chained_calls, [deposit_a, 0], [0, withdraw_b]) + } else if token_in_id == pool_def_data.definition_token_b_id { + let (chained_calls, deposit_b, withdraw_a) = exact_output_swap_logic( + user_holding_b.clone(), + vault_b.clone(), + vault_a.clone(), + user_holding_a.clone(), + exact_amount_out, + max_amount_in, + pool_def_data.reserve_b, + pool_def_data.reserve_a, + pool.account_id, + ); + + (chained_calls, [0, withdraw_a], [deposit_b, 0]) + } else { + panic!("AccountId is not a token type for the pool"); + }; + + let post_states = create_swap_post_states( + pool, + pool_def_data, + vault_a, + vault_b, + user_holding_a, + user_holding_b, + deposit_a, + withdraw_a, + deposit_b, + withdraw_b, + ); + + (post_states, chained_calls) +} + +#[expect(clippy::too_many_arguments, reason = "TODO: Fix later")] +fn exact_output_swap_logic( + user_deposit: AccountWithMetadata, + vault_deposit: AccountWithMetadata, + vault_withdraw: AccountWithMetadata, + user_withdraw: AccountWithMetadata, + exact_amount_out: u128, + max_amount_in: u128, + reserve_deposit_vault_amount: u128, + reserve_withdraw_vault_amount: u128, + pool_id: AccountId, +) -> (Vec, u128, u128) { + // Guard: exact_amount_out must be nonzero + assert_ne!(exact_amount_out, 0, "Exact amount out must be nonzero"); + + // Guard: exact_amount_out must be less than reserve_withdraw_vault_amount + assert!( + exact_amount_out < reserve_withdraw_vault_amount, + "Exact amount out exceeds reserve" + ); + + // Compute deposit amount using ceiling division + // Formula: amount_in = ceil(reserve_in * exact_amount_out / (reserve_out - exact_amount_out)) + let deposit_amount = reserve_deposit_vault_amount + .checked_mul(exact_amount_out) + .expect("reserve * amount_out overflows u128") + .div_ceil(reserve_withdraw_vault_amount - exact_amount_out); + + // Slippage check + assert!( + deposit_amount <= max_amount_in, + "Required input exceeds maximum amount in" + ); + + let token_program_id = user_deposit.account.program_owner; + + let mut chained_calls = Vec::new(); + chained_calls.push(ChainedCall::new( + token_program_id, + vec![user_deposit, vault_deposit], + &token_core::Instruction::Transfer { + amount_to_transfer: deposit_amount, + }, + )); + + let mut vault_withdraw = vault_withdraw; + vault_withdraw.is_authorized = true; + + let pda_seed = compute_vault_pda_seed( + pool_id, + token_core::TokenHolding::try_from(&vault_withdraw.account.data) + .expect("Exact Output Swap Logic: AMM Program expects valid token data") + .definition_id(), + ); + + chained_calls.push( + ChainedCall::new( + token_program_id, + vec![vault_withdraw, user_withdraw], + &token_core::Instruction::Transfer { + amount_to_transfer: exact_amount_out, + }, + ) + .with_pda_seeds(vec![pda_seed]), + ); + + (chained_calls, deposit_amount, exact_amount_out) +} diff --git a/programs/amm/src/tests.rs b/programs/amm/src/tests.rs index 13ae7a89..43e20168 100644 --- a/programs/amm/src/tests.rs +++ b/programs/amm/src/tests.rs @@ -14,7 +14,10 @@ use nssa_core::{ use token_core::{TokenDefinition, TokenHolding}; use crate::{ - add::add_liquidity, new_definition::new_definition, remove::remove_liquidity, swap::swap, + add::add_liquidity, + new_definition::new_definition, + remove::remove_liquidity, + swap::{swap_exact_input, swap_exact_output}, }; const TOKEN_PROGRAM_ID: ProgramId = [15; 8]; @@ -153,6 +156,10 @@ impl BalanceForTests { 200 } + fn max_amount_in() -> u128 { + 166 + } + fn vault_a_add_successful() -> u128 { 1_400 } @@ -243,6 +250,74 @@ impl ChainedCallForTests { ) } + fn cc_swap_exact_output_token_a_test_1() -> ChainedCall { + let swap_amount: u128 = 498; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::vault_a_init(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + } + + fn cc_swap_exact_output_token_b_test_1() -> ChainedCall { + let swap_amount: u128 = 166; + + let mut vault_b_auth = AccountWithMetadataForTests::vault_b_init(); + vault_b_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_b_auth, AccountWithMetadataForTests::user_holding_b()], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_b_definition_id(), + )]) + } + + fn cc_swap_exact_output_token_a_test_2() -> ChainedCall { + let swap_amount: u128 = 285; + + let mut vault_a_auth = AccountWithMetadataForTests::vault_a_init(); + vault_a_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_a_auth, AccountWithMetadataForTests::user_holding_a()], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_a_definition_id(), + )]) + } + + fn cc_swap_exact_output_token_b_test_2() -> ChainedCall { + let swap_amount: u128 = 200; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::vault_b_init(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: swap_amount, + }, + ) + } + fn cc_add_token_a() -> ChainedCall { ChainedCall::new( TOKEN_PROGRAM_ID, @@ -829,6 +904,54 @@ impl AccountWithMetadataForTests { } } + fn pool_definition_swap_exact_output_test_1() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0_u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::lp_supply_init(), + reserve_a: 1498_u128, + reserve_b: 334_u128, + fees: 0_u128, + active: true, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + + fn pool_definition_swap_exact_output_test_2() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0_u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: BalanceForTests::lp_supply_init(), + reserve_a: BalanceForTests::vault_a_swap_test_2(), + reserve_b: BalanceForTests::vault_b_swap_test_2(), + fees: 0_u128, + active: true, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + fn pool_definition_add_zero_lp() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -2400,7 +2523,7 @@ fn call_new_definition_chained_call_successful() { #[should_panic(expected = "AccountId is not a token type for the pool")] #[test] fn call_swap_incorrect_token_type() { - let _post_states = swap( + let _post_states = swap_exact_input( AccountWithMetadataForTests::pool_definition_init(), AccountWithMetadataForTests::vault_a_init(), AccountWithMetadataForTests::vault_b_init(), @@ -2415,7 +2538,7 @@ fn call_swap_incorrect_token_type() { #[should_panic(expected = "Vault A was not provided")] #[test] fn call_swap_vault_a_omitted() { - let _post_states = swap( + let _post_states = swap_exact_input( AccountWithMetadataForTests::pool_definition_init(), AccountWithMetadataForTests::vault_a_with_wrong_id(), AccountWithMetadataForTests::vault_b_init(), @@ -2430,7 +2553,7 @@ fn call_swap_vault_a_omitted() { #[should_panic(expected = "Vault B was not provided")] #[test] fn call_swap_vault_b_omitted() { - let _post_states = swap( + let _post_states = swap_exact_input( AccountWithMetadataForTests::pool_definition_init(), AccountWithMetadataForTests::vault_a_init(), AccountWithMetadataForTests::vault_b_with_wrong_id(), @@ -2445,7 +2568,7 @@ fn call_swap_vault_b_omitted() { #[should_panic(expected = "Reserve for Token A exceeds vault balance")] #[test] fn call_swap_reserves_vault_mismatch_1() { - let _post_states = swap( + let _post_states = swap_exact_input( AccountWithMetadataForTests::pool_definition_init(), AccountWithMetadataForTests::vault_a_init_low(), AccountWithMetadataForTests::vault_b_init(), @@ -2460,7 +2583,7 @@ fn call_swap_reserves_vault_mismatch_1() { #[should_panic(expected = "Reserve for Token B exceeds vault balance")] #[test] fn call_swap_reserves_vault_mismatch_2() { - let _post_states = swap( + let _post_states = swap_exact_input( AccountWithMetadataForTests::pool_definition_init(), AccountWithMetadataForTests::vault_a_init(), AccountWithMetadataForTests::vault_b_init_low(), @@ -2475,7 +2598,7 @@ fn call_swap_reserves_vault_mismatch_2() { #[should_panic(expected = "Pool is inactive")] #[test] fn call_swap_ianctive() { - let _post_states = swap( + let _post_states = swap_exact_input( AccountWithMetadataForTests::pool_definition_inactive(), AccountWithMetadataForTests::vault_a_init(), AccountWithMetadataForTests::vault_b_init(), @@ -2490,7 +2613,7 @@ fn call_swap_ianctive() { #[should_panic(expected = "Withdraw amount is less than minimal amount out")] #[test] fn call_swap_below_min_out() { - let _post_states = swap( + let _post_states = swap_exact_input( AccountWithMetadataForTests::pool_definition_init(), AccountWithMetadataForTests::vault_a_init(), AccountWithMetadataForTests::vault_b_init(), @@ -2504,7 +2627,7 @@ fn call_swap_below_min_out() { #[test] fn call_swap_chained_call_successful_1() { - let (post_states, chained_calls) = swap( + let (post_states, chained_calls) = swap_exact_input( AccountWithMetadataForTests::pool_definition_init(), AccountWithMetadataForTests::vault_a_init(), AccountWithMetadataForTests::vault_b_init(), @@ -2536,7 +2659,7 @@ fn call_swap_chained_call_successful_1() { #[test] fn call_swap_chained_call_successful_2() { - let (post_states, chained_calls) = swap( + let (post_states, chained_calls) = swap_exact_input( AccountWithMetadataForTests::pool_definition_init(), AccountWithMetadataForTests::vault_a_init(), AccountWithMetadataForTests::vault_b_init(), @@ -2566,6 +2689,281 @@ fn call_swap_chained_call_successful_2() { ); } +#[should_panic(expected = "AccountId is not a token type for the pool")] +#[test] +fn call_swap_exact_output_incorrect_token_type() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_lp_definition_id(), + ); +} + +#[should_panic(expected = "Vault A was not provided")] +#[test] +fn call_swap_exact_output_vault_a_omitted() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_with_wrong_id(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Vault B was not provided")] +#[test] +fn call_swap_exact_output_vault_b_omitted() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_with_wrong_id(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Reserve for Token A exceeds vault balance")] +#[test] +fn call_swap_exact_output_reserves_vault_mismatch_1() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init_low(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Reserve for Token B exceeds vault balance")] +#[test] +fn call_swap_exact_output_reserves_vault_mismatch_2() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init_low(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Pool is inactive")] +#[test] +fn call_swap_exact_output_inactive() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_inactive(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Required input exceeds maximum amount in")] +#[test] +fn call_swap_exact_output_exceeds_max_in() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 166_u128, + 100_u128, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Exact amount out must be nonzero")] +#[test] +fn call_swap_exact_output_zero() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 0_u128, + 500_u128, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Exact amount out exceeds reserve")] +#[test] +fn call_swap_exact_output_exceeds_reserve() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::vault_b_reserve_init(), + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[test] +fn call_swap_exact_output_chained_call_successful() { + let (post_states, chained_calls) = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::max_amount_in(), + BalanceForTests::vault_b_reserve_init(), + IdForTests::token_a_definition_id(), + ); + + let pool_post = post_states[0].clone(); + + assert!( + AccountWithMetadataForTests::pool_definition_swap_exact_output_test_1().account + == *pool_post.account() + ); + + let chained_call_a = chained_calls[0].clone(); + let chained_call_b = chained_calls[1].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_swap_exact_output_token_a_test_1() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_swap_exact_output_token_b_test_1() + ); +} + +#[test] +fn call_swap_exact_output_chained_call_successful_2() { + let (post_states, chained_calls) = swap_exact_output( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 285, + 300, + IdForTests::token_b_definition_id(), + ); + + let pool_post = post_states[0].clone(); + + assert!( + AccountWithMetadataForTests::pool_definition_swap_exact_output_test_2().account + == *pool_post.account() + ); + + let chained_call_a = chained_calls[1].clone(); + let chained_call_b = chained_calls[0].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_swap_exact_output_token_a_test_2() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_swap_exact_output_token_b_test_2() + ); +} + +// Without the fix, `reserve_a * exact_amount_out` silently wraps to 0 in release mode, +// making `deposit_amount = 0`. The slippage check `0 <= max_amount_in` always passes, +// so an attacker receives `exact_amount_out` tokens while paying nothing. +#[should_panic(expected = "reserve * amount_out overflows u128")] +#[test] +fn swap_exact_output_overflow_protection() { + // reserve_a chosen so that reserve_a * 2 overflows u128: + // (u128::MAX / 2 + 1) * 2 = u128::MAX + 1 → wraps to 0 + let large_reserve: u128 = u128::MAX / 2 + 1; + let reserve_b: u128 = 1_000; + + let pool = AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: 1, + reserve_a: large_reserve, + reserve_b, + fees: 0, + active: true, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + }; + + let vault_a = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance: large_reserve, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::vault_a_id(), + }; + + let vault_b = AccountWithMetadata { + account: Account { + program_owner: TOKEN_PROGRAM_ID, + balance: 0, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance: reserve_b, + }), + nonce: 0_u128.into(), + }, + is_authorized: true, + account_id: IdForTests::vault_b_id(), + }; + + let _result = swap_exact_output( + pool, + vault_a, + vault_b, + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 2, // exact_amount_out: small, valid (< reserve_b) + 1, // max_amount_in: tiny — real deposit would be enormous, but + // overflow wraps it to 0, making 0 <= 1 pass silently + IdForTests::token_a_definition_id(), + ); +} + #[test] fn new_definition_lp_asymmetric_amounts() { let (post_states, chained_calls) = new_definition( @@ -3064,7 +3462,7 @@ fn simple_amm_add() { fn simple_amm_swap_1() { let mut state = state_for_amm_tests(); - let instruction = amm_core::Instruction::Swap { + let instruction = amm_core::Instruction::SwapExactInput { swap_amount_in: BalanceForExeTests::swap_amount_in(), min_amount_out: BalanceForExeTests::swap_min_amount_out(), token_definition_id_in: IdForExeTests::token_b_definition_id(), @@ -3115,7 +3513,7 @@ fn simple_amm_swap_1() { fn simple_amm_swap_2() { let mut state = state_for_amm_tests(); - let instruction = amm_core::Instruction::Swap { + let instruction = amm_core::Instruction::SwapExactInput { swap_amount_in: BalanceForExeTests::swap_amount_in(), min_amount_out: BalanceForExeTests::swap_min_amount_out(), token_definition_id_in: IdForExeTests::token_a_definition_id(), diff --git a/test_program_methods/guest/src/bin/burner.rs b/test_program_methods/guest/src/bin/burner.rs index 991091c0..06ac9b6b 100644 --- a/test_program_methods/guest/src/bin/burner.rs +++ b/test_program_methods/guest/src/bin/burner.rs @@ -5,6 +5,7 @@ type Instruction = u128; fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: balance_to_burn, }, @@ -20,6 +21,7 @@ fn main() { account_post.balance = account_post.balance.saturating_sub(balance_to_burn); ProgramOutput::new( + self_program_id, instruction_words, vec![pre], vec![AccountPostState::new(account_post)], diff --git a/test_program_methods/guest/src/bin/chain_caller.rs b/test_program_methods/guest/src/bin/chain_caller.rs index c5780665..e8bf9d6f 100644 --- a/test_program_methods/guest/src/bin/chain_caller.rs +++ b/test_program_methods/guest/src/bin/chain_caller.rs @@ -13,6 +13,7 @@ type Instruction = (u128, ProgramId, u32, Option); fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: (balance, auth_transfer_id, num_chain_calls, pda_seed), }, @@ -55,6 +56,7 @@ fn main() { } ProgramOutput::new( + self_program_id, instruction_words, vec![sender_pre.clone(), recipient_pre.clone()], vec![ diff --git a/test_program_methods/guest/src/bin/changer_claimer.rs b/test_program_methods/guest/src/bin/changer_claimer.rs index ee82ec16..c1bd886c 100644 --- a/test_program_methods/guest/src/bin/changer_claimer.rs +++ b/test_program_methods/guest/src/bin/changer_claimer.rs @@ -6,6 +6,7 @@ type Instruction = (Option>, bool); fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: (data_opt, should_claim), }, @@ -33,5 +34,11 @@ fn main() { AccountPostState::new(account_post) }; - ProgramOutput::new(instruction_words, vec![pre], vec![post_state]).write(); + ProgramOutput::new( + self_program_id, + instruction_words, + vec![pre], + vec![post_state], + ) + .write(); } diff --git a/test_program_methods/guest/src/bin/claimer.rs b/test_program_methods/guest/src/bin/claimer.rs index e6239381..27b1ae73 100644 --- a/test_program_methods/guest/src/bin/claimer.rs +++ b/test_program_methods/guest/src/bin/claimer.rs @@ -5,6 +5,7 @@ type Instruction = (); fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: (), }, @@ -17,5 +18,11 @@ fn main() { let account_post = AccountPostState::new_claimed(pre.account.clone(), Claim::Authorized); - ProgramOutput::new(instruction_words, vec![pre], vec![account_post]).write(); + ProgramOutput::new( + self_program_id, + instruction_words, + vec![pre], + vec![account_post], + ) + .write(); } diff --git a/test_program_methods/guest/src/bin/clock_chain_caller.rs b/test_program_methods/guest/src/bin/clock_chain_caller.rs index 913014c2..c6b2d386 100644 --- a/test_program_methods/guest/src/bin/clock_chain_caller.rs +++ b/test_program_methods/guest/src/bin/clock_chain_caller.rs @@ -11,6 +11,7 @@ type Instruction = (ProgramId, u64); // (clock_program_id, timestamp) fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: (clock_program_id, timestamp), }, @@ -29,7 +30,7 @@ fn main() { pda_seeds: vec![], }; - ProgramOutput::new(instruction_words, pre_states, post_states) + ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states) .with_chained_calls(vec![chained_call]) .write(); } diff --git a/test_program_methods/guest/src/bin/data_changer.rs b/test_program_methods/guest/src/bin/data_changer.rs index 730a7180..ee7cb235 100644 --- a/test_program_methods/guest/src/bin/data_changer.rs +++ b/test_program_methods/guest/src/bin/data_changer.rs @@ -6,6 +6,7 @@ type Instruction = Vec; fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: data, }, @@ -23,6 +24,7 @@ fn main() { .expect("provided data should fit into data limit"); ProgramOutput::new( + self_program_id, instruction_words, vec![pre], vec![AccountPostState::new_claimed( diff --git a/test_program_methods/guest/src/bin/extra_output.rs b/test_program_methods/guest/src/bin/extra_output.rs index 3adc591c..924f4d8f 100644 --- a/test_program_methods/guest/src/bin/extra_output.rs +++ b/test_program_methods/guest/src/bin/extra_output.rs @@ -6,7 +6,14 @@ use nssa_core::{ type Instruction = (); fn main() { - let (ProgramInput { pre_states, .. }, instruction_words) = read_nssa_inputs::(); + let ( + ProgramInput { + self_program_id, + pre_states, + .. + }, + instruction_words, + ) = read_nssa_inputs::(); let Ok([pre]) = <[_; 1]>::try_from(pre_states) else { return; @@ -15,6 +22,7 @@ fn main() { let account_pre = pre.account.clone(); ProgramOutput::new( + self_program_id, instruction_words, vec![pre], vec![ diff --git a/test_program_methods/guest/src/bin/malicious_authorization_changer.rs b/test_program_methods/guest/src/bin/malicious_authorization_changer.rs index 7452d337..1db09a73 100644 --- a/test_program_methods/guest/src/bin/malicious_authorization_changer.rs +++ b/test_program_methods/guest/src/bin/malicious_authorization_changer.rs @@ -14,6 +14,7 @@ type Instruction = (u128, ProgramId); fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: (balance, transfer_program_id), }, @@ -40,6 +41,7 @@ fn main() { }; ProgramOutput::new( + self_program_id, instruction_words, vec![sender.clone(), receiver.clone()], vec![ diff --git a/test_program_methods/guest/src/bin/minter.rs b/test_program_methods/guest/src/bin/minter.rs index ac29e4d3..445df32f 100644 --- a/test_program_methods/guest/src/bin/minter.rs +++ b/test_program_methods/guest/src/bin/minter.rs @@ -3,7 +3,14 @@ use nssa_core::program::{AccountPostState, ProgramInput, ProgramOutput, read_nss type Instruction = (); fn main() { - let (ProgramInput { pre_states, .. }, instruction_words) = read_nssa_inputs::(); + let ( + ProgramInput { + self_program_id, + pre_states, + .. + }, + instruction_words, + ) = read_nssa_inputs::(); let Ok([pre]) = <[_; 1]>::try_from(pre_states) else { return; @@ -17,6 +24,7 @@ fn main() { .expect("Balance overflow"); ProgramOutput::new( + self_program_id, instruction_words, vec![pre], vec![AccountPostState::new(account_post)], diff --git a/test_program_methods/guest/src/bin/missing_output.rs b/test_program_methods/guest/src/bin/missing_output.rs index b485e87a..6b33d95e 100644 --- a/test_program_methods/guest/src/bin/missing_output.rs +++ b/test_program_methods/guest/src/bin/missing_output.rs @@ -3,7 +3,14 @@ use nssa_core::program::{AccountPostState, ProgramInput, ProgramOutput, read_nss type Instruction = (); fn main() { - let (ProgramInput { pre_states, .. }, instruction_words) = read_nssa_inputs::(); + let ( + ProgramInput { + self_program_id, + pre_states, + .. + }, + instruction_words, + ) = read_nssa_inputs::(); let Ok([pre1, pre2]) = <[_; 2]>::try_from(pre_states) else { return; @@ -12,6 +19,7 @@ fn main() { let account_pre1 = pre1.account.clone(); ProgramOutput::new( + self_program_id, instruction_words, vec![pre1, pre2], vec![AccountPostState::new(account_pre1)], diff --git a/test_program_methods/guest/src/bin/modified_transfer.rs b/test_program_methods/guest/src/bin/modified_transfer.rs index a89c72fb..859f5cc0 100644 --- a/test_program_methods/guest/src/bin/modified_transfer.rs +++ b/test_program_methods/guest/src/bin/modified_transfer.rs @@ -64,6 +64,7 @@ fn main() { // Read input accounts. let ( ProgramInput { + self_program_id, pre_states, instruction: balance_to_move, }, @@ -80,5 +81,5 @@ fn main() { } _ => panic!("invalid params"), }; - ProgramOutput::new(instruction_data, pre_states, post_states).write(); + ProgramOutput::new(self_program_id, instruction_data, pre_states, post_states).write(); } diff --git a/test_program_methods/guest/src/bin/nonce_changer.rs b/test_program_methods/guest/src/bin/nonce_changer.rs index 0cecdc81..5e1cdbb2 100644 --- a/test_program_methods/guest/src/bin/nonce_changer.rs +++ b/test_program_methods/guest/src/bin/nonce_changer.rs @@ -3,7 +3,14 @@ use nssa_core::program::{AccountPostState, ProgramInput, ProgramOutput, read_nss type Instruction = (); fn main() { - let (ProgramInput { pre_states, .. }, instruction_words) = read_nssa_inputs::(); + let ( + ProgramInput { + self_program_id, + pre_states, + .. + }, + instruction_words, + ) = read_nssa_inputs::(); let Ok([pre]) = <[_; 1]>::try_from(pre_states) else { return; @@ -14,6 +21,7 @@ fn main() { account_post.nonce.public_account_nonce_increment(); ProgramOutput::new( + self_program_id, instruction_words, vec![pre], vec![AccountPostState::new(account_post)], diff --git a/test_program_methods/guest/src/bin/noop.rs b/test_program_methods/guest/src/bin/noop.rs index 35a07765..71787776 100644 --- a/test_program_methods/guest/src/bin/noop.rs +++ b/test_program_methods/guest/src/bin/noop.rs @@ -3,11 +3,18 @@ use nssa_core::program::{AccountPostState, ProgramInput, ProgramOutput, read_nss type Instruction = (); fn main() { - let (ProgramInput { pre_states, .. }, instruction_words) = read_nssa_inputs::(); + let ( + ProgramInput { + self_program_id, + pre_states, + .. + }, + instruction_words, + ) = read_nssa_inputs::(); let post_states = pre_states .iter() .map(|account| AccountPostState::new(account.account.clone())) .collect(); - ProgramOutput::new(instruction_words, pre_states, post_states).write(); + ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write(); } diff --git a/test_program_methods/guest/src/bin/program_owner_changer.rs b/test_program_methods/guest/src/bin/program_owner_changer.rs index 7e421351..f1b2cfce 100644 --- a/test_program_methods/guest/src/bin/program_owner_changer.rs +++ b/test_program_methods/guest/src/bin/program_owner_changer.rs @@ -3,7 +3,14 @@ use nssa_core::program::{AccountPostState, ProgramInput, ProgramOutput, read_nss type Instruction = (); fn main() { - let (ProgramInput { pre_states, .. }, instruction_words) = read_nssa_inputs::(); + let ( + ProgramInput { + self_program_id, + pre_states, + .. + }, + instruction_words, + ) = read_nssa_inputs::(); let Ok([pre]) = <[_; 1]>::try_from(pre_states) else { return; @@ -14,6 +21,7 @@ fn main() { account_post.program_owner = [0, 1, 2, 3, 4, 5, 6, 7]; ProgramOutput::new( + self_program_id, instruction_words, vec![pre], vec![AccountPostState::new(account_post)], diff --git a/test_program_methods/guest/src/bin/simple_balance_transfer.rs b/test_program_methods/guest/src/bin/simple_balance_transfer.rs index 9ee715e8..4edd6198 100644 --- a/test_program_methods/guest/src/bin/simple_balance_transfer.rs +++ b/test_program_methods/guest/src/bin/simple_balance_transfer.rs @@ -5,6 +5,7 @@ type Instruction = u128; fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: balance, }, @@ -27,6 +28,7 @@ fn main() { .expect("Overflow when adding balance"); ProgramOutput::new( + self_program_id, instruction_words, vec![sender_pre, receiver_pre], vec![ diff --git a/test_program_methods/guest/src/bin/validity_window.rs b/test_program_methods/guest/src/bin/validity_window.rs index a0ff9f36..67908836 100644 --- a/test_program_methods/guest/src/bin/validity_window.rs +++ b/test_program_methods/guest/src/bin/validity_window.rs @@ -8,6 +8,7 @@ type Instruction = (BlockValidityWindow, TimestampValidityWindow); fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: (block_validity_window, timestamp_validity_window), }, @@ -21,6 +22,7 @@ fn main() { let post = pre.account.clone(); ProgramOutput::new( + self_program_id, instruction_words, vec![pre], vec![AccountPostState::new(post)], diff --git a/test_program_methods/guest/src/bin/validity_window_chain_caller.rs b/test_program_methods/guest/src/bin/validity_window_chain_caller.rs index 39f8ad69..cbe3c7c1 100644 --- a/test_program_methods/guest/src/bin/validity_window_chain_caller.rs +++ b/test_program_methods/guest/src/bin/validity_window_chain_caller.rs @@ -16,6 +16,7 @@ type Instruction = (BlockValidityWindow, ProgramId, BlockValidityWindow); fn main() { let ( ProgramInput { + self_program_id, pre_states, instruction: (block_validity_window, chained_program_id, chained_block_validity_window), }, @@ -38,6 +39,7 @@ fn main() { }; ProgramOutput::new( + self_program_id, instruction_words, vec![pre], vec![AccountPostState::new(post)], diff --git a/wallet-ffi/src/wallet.rs b/wallet-ffi/src/wallet.rs index 9117d0ee..93fc20aa 100644 --- a/wallet-ffi/src/wallet.rs +++ b/wallet-ffi/src/wallet.rs @@ -111,8 +111,8 @@ pub unsafe extern "C" fn wallet_ffi_create_new( return ptr::null_mut(); }; - match WalletCore::new_init_storage(config_path, storage_path, None, password) { - Ok(core) => { + match WalletCore::new_init_storage(config_path, storage_path, None, &password) { + Ok((core, _mnemonic)) => { let wrapper = Box::new(WalletWrapper { core: Mutex::new(core), }); diff --git a/wallet/Cargo.toml b/wallet/Cargo.toml index f77988a0..4e98b8ef 100644 --- a/wallet/Cargo.toml +++ b/wallet/Cargo.toml @@ -17,6 +17,7 @@ token_core.workspace = true amm_core.workspace = true testnet_initial_state.workspace = true ata_core.workspace = true +bip39.workspace = true anyhow.workspace = true thiserror.workspace = true diff --git a/wallet/src/chain_storage.rs b/wallet/src/chain_storage.rs index ebfe9896..3699609b 100644 --- a/wallet/src/chain_storage.rs +++ b/wallet/src/chain_storage.rs @@ -1,6 +1,7 @@ use std::collections::{BTreeMap, HashMap, btree_map::Entry}; use anyhow::Result; +use bip39::Mnemonic; use key_protocol::{ key_management::{ key_tree::{KeyTreePrivate, KeyTreePublic, chain_index::ChainIndex}, @@ -95,7 +96,7 @@ impl WalletChainStore { }) } - pub fn new_storage(config: WalletConfig, password: String) -> Result { + pub fn new_storage(config: WalletConfig, password: &str) -> Result<(Self, Mnemonic)> { let mut public_init_acc_map = BTreeMap::new(); let mut private_init_acc_map = BTreeMap::new(); @@ -121,13 +122,43 @@ impl WalletChainStore { } } - let public_tree = KeyTreePublic::new(&SeedHolder::new_mnemonic(password.clone())); - let private_tree = KeyTreePrivate::new(&SeedHolder::new_mnemonic(password)); + // TODO: Use password for storage encryption + let _ = password; + let (seed_holder, mnemonic) = SeedHolder::new_mnemonic(""); + let public_tree = KeyTreePublic::new(&seed_holder); + let private_tree = KeyTreePrivate::new(&seed_holder); + + Ok(( + Self { + user_data: NSSAUserData::new_with_accounts( + public_init_acc_map, + private_init_acc_map, + public_tree, + private_tree, + )?, + wallet_config: config, + labels: HashMap::new(), + }, + mnemonic, + )) + } + + /// Restore storage from an existing mnemonic phrase. + pub fn restore_storage( + config: WalletConfig, + mnemonic: &Mnemonic, + password: &str, + ) -> Result { + // TODO: Use password for storage encryption + let _ = password; + let seed_holder = SeedHolder::from_mnemonic(mnemonic, ""); + let public_tree = KeyTreePublic::new(&seed_holder); + let private_tree = KeyTreePrivate::new(&seed_holder); Ok(Self { user_data: NSSAUserData::new_with_accounts( - public_init_acc_map, - private_init_acc_map, + BTreeMap::new(), + BTreeMap::new(), public_tree, private_tree, )?, diff --git a/wallet/src/cli/mod.rs b/wallet/src/cli/mod.rs index 6463dee8..1653e938 100644 --- a/wallet/src/cli/mod.rs +++ b/wallet/src/cli/mod.rs @@ -1,6 +1,7 @@ -use std::{io::Write as _, path::PathBuf}; +use std::{io::Write as _, path::PathBuf, str::FromStr as _}; use anyhow::{Context as _, Result}; +use bip39::Mnemonic; use clap::{Parser, Subcommand}; use common::{HashType, transaction::NSSATransaction}; use futures::TryFutureExt as _; @@ -167,8 +168,9 @@ pub async fn execute_subcommand( config_subcommand.handle_subcommand(wallet_core).await? } Command::RestoreKeys { depth } => { + let mnemonic = read_mnemonic_from_stdin()?; let password = read_password_from_stdin()?; - wallet_core.reset_storage(password)?; + wallet_core.restore_storage(&mnemonic, &password)?; execute_keys_restoration(wallet_core, depth).await?; SubcommandReturnValue::Empty @@ -212,6 +214,16 @@ pub fn read_password_from_stdin() -> Result { Ok(password.trim().to_owned()) } +pub fn read_mnemonic_from_stdin() -> Result { + let mut phrase = String::new(); + + print!("Input recovery phrase: "); + std::io::stdout().flush()?; + std::io::stdin().read_line(&mut phrase)?; + + Mnemonic::from_str(phrase.trim()).context("Invalid mnemonic phrase") +} + pub async fn execute_keys_restoration(wallet_core: &mut WalletCore, depth: u32) -> Result<()> { wallet_core .storage diff --git a/wallet/src/cli/programs/amm.rs b/wallet/src/cli/programs/amm.rs index 7307569d..0b721d15 100644 --- a/wallet/src/cli/programs/amm.rs +++ b/wallet/src/cli/programs/amm.rs @@ -32,12 +32,12 @@ pub enum AmmProgramAgnosticSubcommand { #[arg(long)] balance_b: u128, }, - /// Swap. + /// Swap specifying exact input amount. /// /// The account associated with swapping token must be owned. /// /// Only public execution allowed. - Swap { + SwapExactInput { /// `user_holding_a` - valid 32 byte base58 string with privacy prefix. #[arg(long)] user_holding_a: String, @@ -52,6 +52,26 @@ pub enum AmmProgramAgnosticSubcommand { #[arg(long)] token_definition: String, }, + /// Swap specifying exact output amount. + /// + /// The account associated with swapping token must be owned. + /// + /// Only public execution allowed. + SwapExactOutput { + /// `user_holding_a` - valid 32 byte base58 string with privacy prefix. + #[arg(long)] + user_holding_a: String, + /// `user_holding_b` - valid 32 byte base58 string with privacy prefix. + #[arg(long)] + user_holding_b: String, + #[arg(long)] + exact_amount_out: u128, + #[arg(long)] + max_amount_in: u128, + /// `token_definition` - valid 32 byte base58 string WITHOUT privacy prefix. + #[arg(long)] + token_definition: String, + }, /// Add liquidity. /// /// `user_holding_a` and `user_holding_b` must be owned. @@ -150,7 +170,7 @@ impl WalletSubcommand for AmmProgramAgnosticSubcommand { } } } - Self::Swap { + Self::SwapExactInput { user_holding_a, user_holding_b, amount_in, @@ -168,7 +188,7 @@ impl WalletSubcommand for AmmProgramAgnosticSubcommand { match (user_holding_a_privacy, user_holding_b_privacy) { (AccountPrivacyKind::Public, AccountPrivacyKind::Public) => { Amm(wallet_core) - .send_swap( + .send_swap_exact_input( user_holding_a, user_holding_b, amount_in, @@ -185,6 +205,41 @@ impl WalletSubcommand for AmmProgramAgnosticSubcommand { } } } + Self::SwapExactOutput { + user_holding_a, + user_holding_b, + exact_amount_out, + max_amount_in, + token_definition, + } => { + let (user_holding_a, user_holding_a_privacy) = + parse_addr_with_privacy_prefix(&user_holding_a)?; + let (user_holding_b, user_holding_b_privacy) = + parse_addr_with_privacy_prefix(&user_holding_b)?; + + let user_holding_a: AccountId = user_holding_a.parse()?; + let user_holding_b: AccountId = user_holding_b.parse()?; + + match (user_holding_a_privacy, user_holding_b_privacy) { + (AccountPrivacyKind::Public, AccountPrivacyKind::Public) => { + Amm(wallet_core) + .send_swap_exact_output( + user_holding_a, + user_holding_b, + exact_amount_out, + max_amount_in, + token_definition.parse()?, + ) + .await?; + + Ok(SubcommandReturnValue::Empty) + } + _ => { + // ToDo: Implement after private multi-chain calls is available + anyhow::bail!("Only public execution allowed for Amm calls"); + } + } + } Self::AddLiquidity { user_holding_a, user_holding_b, diff --git a/wallet/src/lib.rs b/wallet/src/lib.rs index a09d477e..63ea8611 100644 --- a/wallet/src/lib.rs +++ b/wallet/src/lib.rs @@ -11,6 +11,7 @@ use std::path::PathBuf; use anyhow::{Context as _, Result}; +use bip39::Mnemonic; use chain_storage::WalletChainStore; use common::{HashType, transaction::NSSATransaction}; use config::WalletConfig; @@ -117,15 +118,24 @@ impl WalletCore { config_path: PathBuf, storage_path: PathBuf, config_overrides: Option, - password: String, - ) -> Result { - Self::new( + password: &str, + ) -> Result<(Self, Mnemonic)> { + let mut mnemonic_out = None; + let wallet = Self::new( config_path, storage_path, config_overrides, - |config| WalletChainStore::new_storage(config, password), + |config| { + let (storage, mnemonic) = WalletChainStore::new_storage(config, password)?; + mnemonic_out = Some(mnemonic); + Ok(storage) + }, 0, - ) + )?; + Ok(( + wallet, + mnemonic_out.expect("mnemonic should be set after new_storage"), + )) } fn new( @@ -191,9 +201,13 @@ impl WalletCore { &self.storage } - /// Reset storage. - pub fn reset_storage(&mut self, password: String) -> Result<()> { - self.storage = WalletChainStore::new_storage(self.storage.wallet_config.clone(), password)?; + /// Restore storage from an existing mnemonic phrase. + pub fn restore_storage(&mut self, mnemonic: &Mnemonic, password: &str) -> Result<()> { + self.storage = WalletChainStore::restore_storage( + self.storage.wallet_config.clone(), + mnemonic, + password, + )?; Ok(()) } diff --git a/wallet/src/main.rs b/wallet/src/main.rs index e055bd63..cf8356db 100644 --- a/wallet/src/main.rs +++ b/wallet/src/main.rs @@ -46,13 +46,21 @@ async fn main() -> Result<()> { println!("Persistent storage not found, need to execute setup"); let password = read_password_from_stdin()?; - let wallet = WalletCore::new_init_storage( + let (wallet, mnemonic) = WalletCore::new_init_storage( config_path, storage_path, Some(config_overrides), - password, + &password, )?; + println!(); + println!("IMPORTANT: Write down your recovery phrase and store it securely."); + println!("This is the only way to recover your wallet if you lose access."); + println!(); + println!("Recovery phrase:"); + println!(" {mnemonic}"); + println!(); + wallet.store_persistent_data().await?; wallet }; diff --git a/wallet/src/program_facades/amm.rs b/wallet/src/program_facades/amm.rs index d68de7a5..b31d0658 100644 --- a/wallet/src/program_facades/amm.rs +++ b/wallet/src/program_facades/amm.rs @@ -121,7 +121,7 @@ impl Amm<'_> { .await?) } - pub async fn send_swap( + pub async fn send_swap_exact_input( &self, user_holding_a: AccountId, user_holding_b: AccountId, @@ -129,7 +129,7 @@ impl Amm<'_> { min_amount_out: u128, token_definition_id_in: AccountId, ) -> Result { - let instruction = amm_core::Instruction::Swap { + let instruction = amm_core::Instruction::SwapExactInput { swap_amount_in, min_amount_out, token_definition_id_in, @@ -168,34 +168,105 @@ impl Amm<'_> { user_holding_b, ]; - let account_id_auth; + let account_id_auth = if definition_token_a_id == token_definition_id_in { + user_holding_a + } else if definition_token_b_id == token_definition_id_in { + user_holding_b + } else { + return Err(ExecutionFailureKind::AccountDataError( + token_definition_id_in, + )); + }; - // Checking, which account are associated with TokenDefinition - let token_holder_acc_a = self + let nonces = self + .0 + .get_accounts_nonces(vec![account_id_auth]) + .await + .map_err(ExecutionFailureKind::SequencerError)?; + + let signing_key = self + .0 + .storage + .user_data + .get_pub_account_signing_key(account_id_auth) + .ok_or(ExecutionFailureKind::KeyNotFoundError)?; + + let message = nssa::public_transaction::Message::try_new( + program.id(), + account_ids, + nonces, + instruction, + ) + .unwrap(); + + let witness_set = + nssa::public_transaction::WitnessSet::for_message(&message, &[signing_key]); + + let tx = nssa::PublicTransaction::new(message, witness_set); + + Ok(self + .0 + .sequencer_client + .send_transaction(NSSATransaction::Public(tx)) + .await?) + } + + pub async fn send_swap_exact_output( + &self, + user_holding_a: AccountId, + user_holding_b: AccountId, + exact_amount_out: u128, + max_amount_in: u128, + token_definition_id_in: AccountId, + ) -> Result { + let instruction = amm_core::Instruction::SwapExactOutput { + exact_amount_out, + max_amount_in, + token_definition_id_in, + }; + let program = Program::amm(); + let amm_program_id = Program::amm().id(); + + let user_a_acc = self .0 .get_account_public(user_holding_a) .await .map_err(ExecutionFailureKind::SequencerError)?; - let token_holder_acc_b = self + let user_b_acc = self .0 .get_account_public(user_holding_b) .await .map_err(ExecutionFailureKind::SequencerError)?; - let token_holder_a = TokenHolding::try_from(&token_holder_acc_a.data) - .map_err(|_err| ExecutionFailureKind::AccountDataError(user_holding_a))?; - let token_holder_b = TokenHolding::try_from(&token_holder_acc_b.data) - .map_err(|_err| ExecutionFailureKind::AccountDataError(user_holding_b))?; + let definition_token_a_id = TokenHolding::try_from(&user_a_acc.data) + .map_err(|_err| ExecutionFailureKind::AccountDataError(user_holding_a))? + .definition_id(); + let definition_token_b_id = TokenHolding::try_from(&user_b_acc.data) + .map_err(|_err| ExecutionFailureKind::AccountDataError(user_holding_b))? + .definition_id(); - if token_holder_a.definition_id() == token_definition_id_in { - account_id_auth = user_holding_a; - } else if token_holder_b.definition_id() == token_definition_id_in { - account_id_auth = user_holding_b; + let amm_pool = + compute_pool_pda(amm_program_id, definition_token_a_id, definition_token_b_id); + let vault_holding_a = compute_vault_pda(amm_program_id, amm_pool, definition_token_a_id); + let vault_holding_b = compute_vault_pda(amm_program_id, amm_pool, definition_token_b_id); + + let account_ids = vec![ + amm_pool, + vault_holding_a, + vault_holding_b, + user_holding_a, + user_holding_b, + ]; + + let account_id_auth = if definition_token_a_id == token_definition_id_in { + user_holding_a + } else if definition_token_b_id == token_definition_id_in { + user_holding_b } else { return Err(ExecutionFailureKind::AccountDataError( token_definition_id_in, )); - } + }; let nonces = self .0