From 3fbf1e1fec33e57de6c0018d695ffa4f695e5033 Mon Sep 17 00:00:00 2001 From: Sergio Chouhy Date: Thu, 27 Nov 2025 13:10:38 -0300 Subject: [PATCH] add pda mechanism --- nssa/core/src/program.rs | 26 +++++++++++++++-- nssa/src/public_transaction/transaction.rs | 33 ++++++++++++++++++---- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index 927d5fc..ad9bbab 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -1,7 +1,7 @@ use risc0_zkvm::{DeserializeOwned, guest::env, serde::Deserializer}; use serde::{Deserialize, Serialize}; -use crate::account::{Account, AccountWithMetadata}; +use crate::account::{Account, AccountId, AccountWithMetadata}; pub type ProgramId = [u32; 8]; pub type InstructionData = Vec; @@ -16,13 +16,35 @@ pub struct ProgramInput { #[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))] pub struct PdaSeed([u8; 32]); +#[cfg(feature = "host")] +impl From<(&ProgramId, &PdaSeed)> for AccountId { + fn from(value: (&ProgramId, &PdaSeed)) -> Self { + use risc0_zkvm::sha::{Impl, Sha256}; + const PROGRAM_DERIVED_ACCOUNT_ID_PREFIX: &[u8; 32] = + b"/NSSA/v0.2/AccountId/PDA/\x00\x00\x00\x00\x00\x00\x00"; + + let mut bytes = [0; 96]; + bytes[0..32].copy_from_slice(PROGRAM_DERIVED_ACCOUNT_ID_PREFIX); + let program_id_bytes: &[u8] = + bytemuck::try_cast_slice(value.0).expect("ProgramId should be castable to &[u8]"); + bytes[32..64].copy_from_slice(program_id_bytes); + bytes[64..].copy_from_slice(&value.1.0); + AccountId::new( + Impl::hash_bytes(&bytes) + .as_bytes() + .try_into() + .expect("Hash output must be exactly 32 bytes long"), + ) + } +} + #[derive(Serialize, Deserialize, Clone)] #[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))] pub struct ChainedCall { pub program_id: ProgramId, pub instruction_data: InstructionData, pub pre_states: Vec, - pub pda_seeds: Vec + pub pda_seeds: Vec, } #[derive(Serialize, Deserialize, Clone)] diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index 081fe2f..cafa27b 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet, VecDeque}; use borsh::{BorshDeserialize, BorshSerialize}; use nssa_core::{ account::{Account, AccountId, AccountWithMetadata}, - program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution}, + program::{ChainedCall, DEFAULT_PROGRAM_ID, PdaSeed, ProgramId, validate_execution}, }; use sha2::{Digest, digest::FixedOutput}; @@ -110,10 +110,10 @@ impl PublicTransaction { pda_seeds: vec![], }; - let mut chained_calls = VecDeque::from_iter([initial_call]); + let mut chained_calls = VecDeque::from_iter([(initial_call, None)]); let mut chain_calls_counter = 0; - while let Some(chained_call) = chained_calls.pop_front() { + while let Some((chained_call, caller_program_id)) = chained_calls.pop_front() { if chain_calls_counter > MAX_NUMBER_CHAINED_CALLS { return Err(NssaError::MaxChainedCallsDepthExceeded); } @@ -126,6 +126,9 @@ impl PublicTransaction { let mut program_output = program.execute(&chained_call.pre_states, &chained_call.instruction_data)?; + let authorized_pdas = + self.compute_authorized_pdas(&caller_program_id, &chained_call.pda_seeds); + for pre in &program_output.pre_states { let account_id = pre.account_id; // Check that the program output pre_states coinicide with the values in the public @@ -138,8 +141,11 @@ impl PublicTransaction { return Err(NssaError::InvalidProgramBehavior); } - // Check that authorization flags are consistent with the provided ones - if pre.is_authorized && !signer_account_ids.contains(&account_id) { + // Check that authorization flags are consistent with the provided ones or + // authorized by program through the PDA mechanism + let is_authorized = signer_account_ids.contains(&account_id) + || authorized_pdas.contains(&account_id); + if pre.is_authorized && !is_authorized { return Err(NssaError::InvalidProgramBehavior); } } @@ -171,7 +177,7 @@ impl PublicTransaction { } for new_call in program_output.chained_calls.into_iter().rev() { - chained_calls.push_front(new_call); + chained_calls.push_front((new_call, Some(chained_call.program_id))); } chain_calls_counter += 1; @@ -179,6 +185,21 @@ impl PublicTransaction { Ok(state_diff) } + + fn compute_authorized_pdas( + &self, + caller_program_id: &Option, + pda_seeds: &[PdaSeed], + ) -> HashSet { + if let Some(caller_program_id) = caller_program_id { + pda_seeds + .iter() + .map(|pda_seed| AccountId::from((caller_program_id, pda_seed))) + .collect() + } else { + HashSet::new() + } + } } #[cfg(test)]