diff --git a/key_protocol/src/key_protocol_core/mod.rs b/key_protocol/src/key_protocol_core/mod.rs index e17c35a7..7218ebde 100644 --- a/key_protocol/src/key_protocol_core/mod.rs +++ b/key_protocol/src/key_protocol_core/mod.rs @@ -22,11 +22,15 @@ pub struct UserPrivateAccountData { } /// Metadata for a shared account (GMS-derived), stored alongside the cached plaintext state. -/// The group label and identifier are needed to re-derive keys during sync. +/// The group label and identifier (or PDA seed) are needed to re-derive keys during sync. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SharedAccountEntry { pub group_label: String, pub identifier: Identifier, + /// For PDA accounts, the seed used to derive keys via `derive_keys_for_pda`. + /// `None` for regular shared accounts (keys derived from identifier via tag). + #[serde(default)] + pub pda_seed: Option, pub account: Account, } diff --git a/wallet/src/cli/account.rs b/wallet/src/cli/account.rs index 1355eb69..3bb7310b 100644 --- a/wallet/src/cli/account.rs +++ b/wallet/src/cli/account.rs @@ -220,6 +220,7 @@ impl WalletSubcommand for NewSubcommand { account_id, group_name.clone(), u128::MAX, + Some(pda_seed), ); println!("PDA shared account from group '{group_name}'"); @@ -259,6 +260,7 @@ impl WalletSubcommand for NewSubcommand { account_id, group_name.clone(), identifier, + None, ); println!("Shared account from group '{group_name}'"); diff --git a/wallet/src/lib.rs b/wallet/src/lib.rs index 7a293139..f179ec44 100644 --- a/wallet/src/lib.rs +++ b/wallet/src/lib.rs @@ -310,6 +310,7 @@ impl WalletCore { account_id: AccountId, group_label: String, identifier: nssa_core::Identifier, + pda_seed: Option, ) { use key_protocol::key_protocol_core::SharedAccountEntry; self.storage.user_data.shared_accounts.insert( @@ -317,6 +318,7 @@ impl WalletCore { SharedAccountEntry { group_label, identifier, + pda_seed, account: Account::default(), }, ); @@ -592,6 +594,77 @@ impl WalletCore { self.storage .insert_private_account_data(affected_account_id, identifier, new_acc); } + + // Scan for updates to shared accounts (GMS-derived). + self.sync_shared_accounts_with_tx(&tx); + } + + fn sync_shared_accounts_with_tx(&mut self, tx: &PrivacyPreservingTransaction) { + let shared_keys: Vec<_> = self + .storage + .user_data + .shared_accounts + .iter() + .filter_map(|(&account_id, entry)| { + let holder = self + .storage + .user_data + .group_key_holders + .get(&entry.group_label)?; + + let keys = entry.pda_seed.as_ref().map_or_else( + || { + let tag = { + use sha2::Digest as _; + let mut hasher = sha2::Sha256::new(); + hasher.update(b"/LEE/v0.3/SharedAccountTag/\x00\x00\x00\x00\x00"); + hasher.update(entry.identifier.to_le_bytes()); + let result: [u8; 32] = hasher.finalize().into(); + result + }; + holder.derive_keys_for_shared_account(&tag) + }, + |pda_seed| holder.derive_keys_for_pda(pda_seed), + ); + let npk = keys.generate_nullifier_public_key(); + let vpk = keys.generate_viewing_public_key(); + let vsk = keys.viewing_secret_key; + Some((account_id, npk, vpk, vsk)) + }) + .collect(); + + for (account_id, npk, vpk, vsk) in shared_keys { + let view_tag = EncryptedAccountData::compute_view_tag(&npk, &vpk); + + for (ciph_id, encrypted_data) in tx + .message() + .encrypted_private_post_states + .iter() + .enumerate() + { + if encrypted_data.view_tag != view_tag { + continue; + } + + let shared_secret = SharedSecretKey::new(&vsk, &encrypted_data.epk); + let commitment = &tx.message.new_commitments[ciph_id]; + + if let Some((_decrypted_identifier, new_acc)) = nssa_core::EncryptionScheme::decrypt( + &encrypted_data.ciphertext, + &shared_secret, + commitment, + ciph_id + .try_into() + .expect("Ciphertext ID is expected to fit in u32"), + ) { + info!("Synced shared account {account_id:#?} with new state {new_acc:#?}"); + if let Some(entry) = self.storage.user_data.shared_accounts.get_mut(&account_id) + { + entry.account = new_acc; + } + } + } + } } #[must_use]