diff --git a/nssa/src/error.rs b/nssa/src/error.rs index 565e02ba..65079d25 100644 --- a/nssa/src/error.rs +++ b/nssa/src/error.rs @@ -93,14 +93,8 @@ pub enum InvalidProgramBehaviorError { actual: Box, }, - #[error( - "Inconsistent authorization for account {account_id} : expected {expected_authorization}, actual {actual_authorization}" - )] - InconsistentAccountAuthorization { - account_id: AccountId, - expected_authorization: bool, - actual_authorization: bool, - }, + #[error("Unauthorized account marked as authorized")] + InvalidAccountAuthorization { account_id: AccountId }, #[error("Program ID mismatch: expected {expected:?}, actual {actual:?}")] MismatchedProgramId { diff --git a/nssa/src/validated_state_diff.rs b/nssa/src/validated_state_diff.rs index 068dc32c..4bd5fb05 100644 --- a/nssa/src/validated_state_diff.rs +++ b/nssa/src/validated_state_diff.rs @@ -8,7 +8,8 @@ use nssa_core::{ BlockId, Commitment, Nullifier, PrivacyPreservingCircuitOutput, Timestamp, account::{Account, AccountId, AccountWithMetadata}, program::{ - ChainedCall, Claim, DEFAULT_PROGRAM_ID, compute_public_authorized_pdas, validate_execution, + ChainedCall, Claim, DEFAULT_PROGRAM_ID, ProgramId, compute_public_authorized_pdas, + validate_execution, }, }; @@ -100,10 +101,26 @@ impl ValidatedStateDiff { pda_seeds: vec![], }; - let mut chained_calls = VecDeque::from_iter([(initial_call, None)]); + #[expect( + clippy::items_after_statements, + reason = "More readable to keep it behind the place where it's used" + )] + #[derive(Debug)] + struct CallerData { + program_id: Option, + authorized_accounts: HashSet, + } + + let initial_caller_data = CallerData { + program_id: None, + authorized_accounts: signer_account_ids.iter().copied().collect(), + }; + + let mut chained_calls = + VecDeque::<(ChainedCall, CallerData)>::from_iter([(initial_call, initial_caller_data)]); let mut chain_calls_counter = 0; - while let Some((chained_call, caller_program_id)) = chained_calls.pop_front() { + while let Some((chained_call, caller_data)) = chained_calls.pop_front() { ensure!( chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS, NssaError::MaxChainedCallsDepthExceeded @@ -119,7 +136,7 @@ impl ValidatedStateDiff { chained_call.program_id, chained_call.pre_states, chained_call.instruction_data ); let mut program_output = program.execute( - caller_program_id, + caller_data.program_id, &chained_call.pre_states, &chained_call.instruction_data, )?; @@ -129,10 +146,13 @@ impl ValidatedStateDiff { ); let authorized_pdas = - compute_public_authorized_pdas(caller_program_id, &chained_call.pda_seeds); + compute_public_authorized_pdas(caller_data.program_id, &chained_call.pda_seeds); + // Account is authorized if it is either in the caller's authorized accounts or in the + // list of PDAs the caller has authorized. let is_authorized = |account_id: &AccountId| { - signer_account_ids.contains(account_id) || authorized_pdas.contains(account_id) + authorized_pdas.contains(account_id) + || caller_data.authorized_accounts.contains(account_id) }; for pre in &program_output.pre_states { @@ -152,16 +172,12 @@ impl ValidatedStateDiff { } ); - // Check that authorization flags are consistent with the provided ones or - // authorized by program through the PDA mechanism - let expected_is_authorized = is_authorized(&account_id); + // Check that the program output pre_states marked as authorized are indeed + // authorized. + let is_indeed_authorized = is_authorized(&account_id); ensure!( - pre.is_authorized == expected_is_authorized, - InvalidProgramBehaviorError::InconsistentAccountAuthorization { - account_id, - expected_authorization: expected_is_authorized, - actual_authorization: pre.is_authorized - } + !pre.is_authorized || is_indeed_authorized, + InvalidProgramBehaviorError::InvalidAccountAuthorization { account_id } ); } @@ -176,9 +192,9 @@ impl ValidatedStateDiff { // Verify that the program output's caller_program_id matches the actual caller. ensure!( - program_output.caller_program_id == caller_program_id, + program_output.caller_program_id == caller_data.program_id, InvalidProgramBehaviorError::MismatchedCallerProgramId { - expected: caller_program_id, + expected: caller_data.program_id, actual: program_output.caller_program_id, } ); @@ -205,7 +221,8 @@ impl ValidatedStateDiff { let Some(claim) = post.required_claim() else { continue; }; - let account_id = program_output.pre_states[i].account_id; + let pre = &program_output.pre_states[i]; + let account_id = pre.account_id; // The invoked program can only claim accounts with default program id. ensure!( @@ -217,7 +234,7 @@ impl ValidatedStateDiff { Claim::Authorized => { // The program can only claim accounts that were authorized by the signer. ensure!( - is_authorized(&account_id), + pre.is_authorized, InvalidProgramBehaviorError::ClaimedUnauthorizedAccount { account_id } ); } @@ -248,8 +265,20 @@ impl ValidatedStateDiff { state_diff.insert(pre.account_id, post.account().clone()); } + let authorized_accounts: HashSet<_> = chained_call + .pre_states + .iter() + .filter(|pre| pre.is_authorized) + .map(|pre| pre.account_id) + .collect(); for new_call in program_output.chained_calls.into_iter().rev() { - chained_calls.push_front((new_call, Some(chained_call.program_id))); + chained_calls.push_front(( + new_call, + CallerData { + program_id: Some(chained_call.program_id), + authorized_accounts: authorized_accounts.clone(), + }, + )); } chain_calls_counter = chain_calls_counter