fix: proper account authorization propagation

This commit is contained in:
Daniil Polyakov 2026-05-14 02:01:12 +03:00
parent ee5a98fc48
commit f721a00bdf
2 changed files with 51 additions and 28 deletions

View File

@ -93,14 +93,8 @@ pub enum InvalidProgramBehaviorError {
actual: Box<Account>, actual: Box<Account>,
}, },
#[error( #[error("Unauthorized account marked as authorized")]
"Inconsistent authorization for account {account_id} : expected {expected_authorization}, actual {actual_authorization}" InvalidAccountAuthorization { account_id: AccountId },
)]
InconsistentAccountAuthorization {
account_id: AccountId,
expected_authorization: bool,
actual_authorization: bool,
},
#[error("Program ID mismatch: expected {expected:?}, actual {actual:?}")] #[error("Program ID mismatch: expected {expected:?}, actual {actual:?}")]
MismatchedProgramId { MismatchedProgramId {

View File

@ -8,7 +8,8 @@ use nssa_core::{
BlockId, Commitment, Nullifier, PrivacyPreservingCircuitOutput, Timestamp, BlockId, Commitment, Nullifier, PrivacyPreservingCircuitOutput, Timestamp,
account::{Account, AccountId, AccountWithMetadata}, account::{Account, AccountId, AccountWithMetadata},
program::{ program::{
ChainedCall, Claim, DEFAULT_PROGRAM_ID, compute_public_authorized_pdas, validate_execution, ChainedCall, Claim, DEFAULT_PROGRAM_ID, ProgramId, compute_public_authorized_pdas,
validate_execution,
}, },
}; };
@ -100,10 +101,26 @@ impl ValidatedStateDiff {
pda_seeds: vec![], pda_seeds: vec![],
}; };
let mut chained_calls = VecDeque::from_iter([(initial_call, None)]); #[expect(
clippy::items_after_statements,
reason = "More readable to keep it behind the place where it's used"
)]
#[derive(Debug)]
struct CallerData {
program_id: Option<ProgramId>,
authorized_accounts: HashSet<AccountId>,
}
let initial_caller_data = CallerData {
program_id: None,
authorized_accounts: signer_account_ids.iter().copied().collect(),
};
let mut chained_calls =
VecDeque::<(ChainedCall, CallerData)>::from_iter([(initial_call, initial_caller_data)]);
let mut chain_calls_counter = 0; let mut chain_calls_counter = 0;
while let Some((chained_call, caller_program_id)) = chained_calls.pop_front() { while let Some((chained_call, caller_data)) = chained_calls.pop_front() {
ensure!( ensure!(
chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS, chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS,
NssaError::MaxChainedCallsDepthExceeded NssaError::MaxChainedCallsDepthExceeded
@ -119,7 +136,7 @@ impl ValidatedStateDiff {
chained_call.program_id, chained_call.pre_states, chained_call.instruction_data chained_call.program_id, chained_call.pre_states, chained_call.instruction_data
); );
let mut program_output = program.execute( let mut program_output = program.execute(
caller_program_id, caller_data.program_id,
&chained_call.pre_states, &chained_call.pre_states,
&chained_call.instruction_data, &chained_call.instruction_data,
)?; )?;
@ -129,10 +146,13 @@ impl ValidatedStateDiff {
); );
let authorized_pdas = let authorized_pdas =
compute_public_authorized_pdas(caller_program_id, &chained_call.pda_seeds); compute_public_authorized_pdas(caller_data.program_id, &chained_call.pda_seeds);
// Account is authorized if it is either in the caller's authorized accounts or in the
// list of PDAs the caller has authorized.
let is_authorized = |account_id: &AccountId| { let is_authorized = |account_id: &AccountId| {
signer_account_ids.contains(account_id) || authorized_pdas.contains(account_id) authorized_pdas.contains(account_id)
|| caller_data.authorized_accounts.contains(account_id)
}; };
for pre in &program_output.pre_states { for pre in &program_output.pre_states {
@ -152,16 +172,12 @@ impl ValidatedStateDiff {
} }
); );
// Check that authorization flags are consistent with the provided ones or // Check that the program output pre_states marked as authorized are indeed
// authorized by program through the PDA mechanism // authorized.
let expected_is_authorized = is_authorized(&account_id); let is_indeed_authorized = is_authorized(&account_id);
ensure!( ensure!(
pre.is_authorized == expected_is_authorized, !pre.is_authorized || is_indeed_authorized,
InvalidProgramBehaviorError::InconsistentAccountAuthorization { InvalidProgramBehaviorError::InvalidAccountAuthorization { account_id }
account_id,
expected_authorization: expected_is_authorized,
actual_authorization: pre.is_authorized
}
); );
} }
@ -176,9 +192,9 @@ impl ValidatedStateDiff {
// Verify that the program output's caller_program_id matches the actual caller. // Verify that the program output's caller_program_id matches the actual caller.
ensure!( ensure!(
program_output.caller_program_id == caller_program_id, program_output.caller_program_id == caller_data.program_id,
InvalidProgramBehaviorError::MismatchedCallerProgramId { InvalidProgramBehaviorError::MismatchedCallerProgramId {
expected: caller_program_id, expected: caller_data.program_id,
actual: program_output.caller_program_id, actual: program_output.caller_program_id,
} }
); );
@ -205,7 +221,8 @@ impl ValidatedStateDiff {
let Some(claim) = post.required_claim() else { let Some(claim) = post.required_claim() else {
continue; continue;
}; };
let account_id = program_output.pre_states[i].account_id; let pre = &program_output.pre_states[i];
let account_id = pre.account_id;
// The invoked program can only claim accounts with default program id. // The invoked program can only claim accounts with default program id.
ensure!( ensure!(
@ -217,7 +234,7 @@ impl ValidatedStateDiff {
Claim::Authorized => { Claim::Authorized => {
// The program can only claim accounts that were authorized by the signer. // The program can only claim accounts that were authorized by the signer.
ensure!( ensure!(
is_authorized(&account_id), pre.is_authorized,
InvalidProgramBehaviorError::ClaimedUnauthorizedAccount { account_id } InvalidProgramBehaviorError::ClaimedUnauthorizedAccount { account_id }
); );
} }
@ -248,8 +265,20 @@ impl ValidatedStateDiff {
state_diff.insert(pre.account_id, post.account().clone()); state_diff.insert(pre.account_id, post.account().clone());
} }
let authorized_accounts: HashSet<_> = chained_call
.pre_states
.iter()
.filter(|pre| pre.is_authorized)
.map(|pre| pre.account_id)
.collect();
for new_call in program_output.chained_calls.into_iter().rev() { for new_call in program_output.chained_calls.into_iter().rev() {
chained_calls.push_front((new_call, Some(chained_call.program_id))); chained_calls.push_front((
new_call,
CallerData {
program_id: Some(chained_call.program_id),
authorized_accounts: authorized_accounts.clone(),
},
));
} }
chain_calls_counter = chain_calls_counter chain_calls_counter = chain_calls_counter