add pda mechanism

This commit is contained in:
Sergio Chouhy 2025-11-27 13:10:38 -03:00
parent d82f06593d
commit 3fbf1e1fec
2 changed files with 51 additions and 8 deletions

View File

@ -1,7 +1,7 @@
use risc0_zkvm::{DeserializeOwned, guest::env, serde::Deserializer};
use serde::{Deserialize, Serialize};
use crate::account::{Account, AccountWithMetadata};
use crate::account::{Account, AccountId, AccountWithMetadata};
pub type ProgramId = [u32; 8];
pub type InstructionData = Vec<u32>;
@ -16,13 +16,35 @@ pub struct ProgramInput<T> {
#[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))]
pub struct PdaSeed([u8; 32]);
#[cfg(feature = "host")]
impl From<(&ProgramId, &PdaSeed)> for AccountId {
fn from(value: (&ProgramId, &PdaSeed)) -> Self {
use risc0_zkvm::sha::{Impl, Sha256};
const PROGRAM_DERIVED_ACCOUNT_ID_PREFIX: &[u8; 32] =
b"/NSSA/v0.2/AccountId/PDA/\x00\x00\x00\x00\x00\x00\x00";
let mut bytes = [0; 96];
bytes[0..32].copy_from_slice(PROGRAM_DERIVED_ACCOUNT_ID_PREFIX);
let program_id_bytes: &[u8] =
bytemuck::try_cast_slice(value.0).expect("ProgramId should be castable to &[u8]");
bytes[32..64].copy_from_slice(program_id_bytes);
bytes[64..].copy_from_slice(&value.1.0);
AccountId::new(
Impl::hash_bytes(&bytes)
.as_bytes()
.try_into()
.expect("Hash output must be exactly 32 bytes long"),
)
}
}
#[derive(Serialize, Deserialize, Clone)]
#[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))]
pub struct ChainedCall {
pub program_id: ProgramId,
pub instruction_data: InstructionData,
pub pre_states: Vec<AccountWithMetadata>,
pub pda_seeds: Vec<PdaSeed>
pub pda_seeds: Vec<PdaSeed>,
}
#[derive(Serialize, Deserialize, Clone)]

View File

@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet, VecDeque};
use borsh::{BorshDeserialize, BorshSerialize};
use nssa_core::{
account::{Account, AccountId, AccountWithMetadata},
program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution},
program::{ChainedCall, DEFAULT_PROGRAM_ID, PdaSeed, ProgramId, validate_execution},
};
use sha2::{Digest, digest::FixedOutput};
@ -110,10 +110,10 @@ impl PublicTransaction {
pda_seeds: vec![],
};
let mut chained_calls = VecDeque::from_iter([initial_call]);
let mut chained_calls = VecDeque::from_iter([(initial_call, None)]);
let mut chain_calls_counter = 0;
while let Some(chained_call) = chained_calls.pop_front() {
while let Some((chained_call, caller_program_id)) = chained_calls.pop_front() {
if chain_calls_counter > MAX_NUMBER_CHAINED_CALLS {
return Err(NssaError::MaxChainedCallsDepthExceeded);
}
@ -126,6 +126,9 @@ impl PublicTransaction {
let mut program_output =
program.execute(&chained_call.pre_states, &chained_call.instruction_data)?;
let authorized_pdas =
self.compute_authorized_pdas(&caller_program_id, &chained_call.pda_seeds);
for pre in &program_output.pre_states {
let account_id = pre.account_id;
// Check that the program output pre_states coinicide with the values in the public
@ -138,8 +141,11 @@ impl PublicTransaction {
return Err(NssaError::InvalidProgramBehavior);
}
// Check that authorization flags are consistent with the provided ones
if pre.is_authorized && !signer_account_ids.contains(&account_id) {
// Check that authorization flags are consistent with the provided ones or
// authorized by program through the PDA mechanism
let is_authorized = signer_account_ids.contains(&account_id)
|| authorized_pdas.contains(&account_id);
if pre.is_authorized && !is_authorized {
return Err(NssaError::InvalidProgramBehavior);
}
}
@ -171,7 +177,7 @@ impl PublicTransaction {
}
for new_call in program_output.chained_calls.into_iter().rev() {
chained_calls.push_front(new_call);
chained_calls.push_front((new_call, Some(chained_call.program_id)));
}
chain_calls_counter += 1;
@ -179,6 +185,21 @@ impl PublicTransaction {
Ok(state_diff)
}
fn compute_authorized_pdas(
&self,
caller_program_id: &Option<ProgramId>,
pda_seeds: &[PdaSeed],
) -> HashSet<AccountId> {
if let Some(caller_program_id) = caller_program_id {
pda_seeds
.iter()
.map(|pda_seed| AccountId::from((caller_program_id, pda_seed)))
.collect()
} else {
HashSet::new()
}
}
}
#[cfg(test)]