Merge branch 'schouhy/add-multi-chain-calls' into schouhy/implement-pda-for-public-accounts

This commit is contained in:
Sergio Chouhy 2025-11-27 12:05:09 -03:00
commit a3e07347a4
7 changed files with 119 additions and 71 deletions

View File

@ -1213,7 +1213,6 @@ pub fn prepare_function_map() -> HashMap<String, TestFunction> {
// }; // };
// let tx = fetch_privacy_preserving_tx(&seq_client, tx_hash.clone()).await; // let tx = fetch_privacy_preserving_tx(&seq_client, tx_hash.clone()).await;
// println!("Waiting for next blocks to check if continoius run fetch account"); // println!("Waiting for next blocks to check if continoius run fetch account");
// tokio::time::sleep(Duration::from_secs(TIME_TO_WAIT_FOR_BLOCK_SECONDS)).await; // tokio::time::sleep(Duration::from_secs(TIME_TO_WAIT_FOR_BLOCK_SECONDS)).await;
// tokio::time::sleep(Duration::from_secs(TIME_TO_WAIT_FOR_BLOCK_SECONDS)).await; // tokio::time::sleep(Duration::from_secs(TIME_TO_WAIT_FOR_BLOCK_SECONDS)).await;
@ -1376,6 +1375,7 @@ pub fn prepare_function_map() -> HashMap<String, TestFunction> {
pub async fn test_pinata() { pub async fn test_pinata() {
info!("########## test_pinata ##########"); info!("########## test_pinata ##########");
let pinata_account_id = PINATA_BASE58; let pinata_account_id = PINATA_BASE58;
let pinata_prize = 150; let pinata_prize = 150;
let solution = 989106; let solution = 989106;
let command = Command::Pinata(PinataProgramAgnosticSubcommand::Claim { let command = Command::Pinata(PinataProgramAgnosticSubcommand::Claim {

View File

@ -17,7 +17,7 @@ pub struct ProgramInput<T> {
pub struct ChainedCall { pub struct ChainedCall {
pub program_id: ProgramId, pub program_id: ProgramId,
pub instruction_data: InstructionData, pub instruction_data: InstructionData,
pub account_indices: Vec<usize>, pub pre_states: Vec<AccountWithMetadata>,
} }
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
@ -25,7 +25,7 @@ pub struct ChainedCall {
pub struct ProgramOutput { pub struct ProgramOutput {
pub pre_states: Vec<AccountWithMetadata>, pub pre_states: Vec<AccountWithMetadata>,
pub post_states: Vec<Account>, pub post_states: Vec<Account>,
pub chained_call: Option<ChainedCall>, pub chained_calls: Vec<ChainedCall>,
} }
pub fn read_nssa_inputs<T: DeserializeOwned>() -> ProgramInput<T> { pub fn read_nssa_inputs<T: DeserializeOwned>() -> ProgramInput<T> {
@ -42,7 +42,7 @@ pub fn write_nssa_outputs(pre_states: Vec<AccountWithMetadata>, post_states: Vec
let output = ProgramOutput { let output = ProgramOutput {
pre_states, pre_states,
post_states, post_states,
chained_call: None, chained_calls: Vec::new(),
}; };
env::commit(&output); env::commit(&output);
} }
@ -50,12 +50,12 @@ pub fn write_nssa_outputs(pre_states: Vec<AccountWithMetadata>, post_states: Vec
pub fn write_nssa_outputs_with_chained_call( pub fn write_nssa_outputs_with_chained_call(
pre_states: Vec<AccountWithMetadata>, pre_states: Vec<AccountWithMetadata>,
post_states: Vec<Account>, post_states: Vec<Account>,
chained_call: Option<ChainedCall>, chained_calls: Vec<ChainedCall>,
) { ) {
let output = ProgramOutput { let output = ProgramOutput {
pre_states, pre_states,
post_states, post_states,
chained_call, chained_calls,
}; };
env::commit(&output); env::commit(&output);
} }

View File

@ -27,11 +27,11 @@ fn main() {
let ProgramOutput { let ProgramOutput {
pre_states, pre_states,
post_states, post_states,
chained_call, chained_calls,
} = program_output; } = program_output;
// TODO: implement chained calls for privacy preserving transactions // TODO: implement chained calls for privacy preserving transactions
if chained_call.is_some() { if !chained_calls.is_empty() {
panic!("Privacy preserving transactions do not support yet chained calls.") panic!("Privacy preserving transactions do not support yet chained calls.")
} }

View File

@ -54,4 +54,7 @@ pub enum NssaError {
#[error("Program already exists")] #[error("Program already exists")]
ProgramAlreadyExists, ProgramAlreadyExists,
#[error("Chain of calls is too long")]
MaxChainedCallsDepthExceeded,
} }

View File

@ -1,9 +1,9 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet, VecDeque};
use borsh::{BorshDeserialize, BorshSerialize}; use borsh::{BorshDeserialize, BorshSerialize};
use nssa_core::{ use nssa_core::{
account::{Account, AccountId, AccountWithMetadata}, account::{Account, AccountId, AccountWithMetadata},
program::{DEFAULT_PROGRAM_ID, validate_execution}, program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution},
}; };
use sha2::{Digest, digest::FixedOutput}; use sha2::{Digest, digest::FixedOutput};
@ -11,6 +11,7 @@ use crate::{
V02State, V02State,
error::NssaError, error::NssaError,
public_transaction::{Message, WitnessSet}, public_transaction::{Message, WitnessSet},
state::MAX_NUMBER_CHAINED_CALLS,
}; };
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)] #[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
@ -18,7 +19,6 @@ pub struct PublicTransaction {
message: Message, message: Message,
witness_set: WitnessSet, witness_set: WitnessSet,
} }
const MAX_NUMBER_CHAINED_CALLS: usize = 10;
impl PublicTransaction { impl PublicTransaction {
pub fn new(message: Message, witness_set: WitnessSet) -> Self { pub fn new(message: Message, witness_set: WitnessSet) -> Self {
@ -89,7 +89,7 @@ impl PublicTransaction {
} }
// Build pre_states for execution // Build pre_states for execution
let mut input_pre_states: Vec<_> = message let input_pre_states: Vec<_> = message
.account_ids .account_ids
.iter() .iter()
.map(|account_id| { .map(|account_id| {
@ -103,22 +103,44 @@ impl PublicTransaction {
let mut state_diff: HashMap<AccountId, Account> = HashMap::new(); let mut state_diff: HashMap<AccountId, Account> = HashMap::new();
let mut program_id = message.program_id; let initial_call = ChainedCall {
let mut instruction_data = message.instruction_data.clone(); program_id: message.program_id,
instruction_data: message.instruction_data.clone(),
pre_states: input_pre_states,
};
let mut chained_calls = VecDeque::from_iter([initial_call]);
let mut chain_calls_counter = 0;
while let Some(chained_call) = chained_calls.pop_front() {
if chain_calls_counter > MAX_NUMBER_CHAINED_CALLS {
return Err(NssaError::MaxChainedCallsDepthExceeded);
}
for _i in 0..MAX_NUMBER_CHAINED_CALLS {
// Check the `program_id` corresponds to a deployed program // Check the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&program_id) else { let Some(program) = state.programs().get(&chained_call.program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into())); return Err(NssaError::InvalidInput("Unknown program".into()));
}; };
let mut program_output = program.execute(&input_pre_states, &instruction_data)?; let mut program_output =
program.execute(&chained_call.pre_states, &chained_call.instruction_data)?;
// This check is equivalent to checking that the program output pre_states coinicide for pre in &program_output.pre_states {
// with the values in the public state or with any modifications to those values let account_id = pre.account_id;
// during the chain of calls. // Check that the program output pre_states coinicide with the values in the public
if input_pre_states != program_output.pre_states { // state or with any modifications to those values during the chain of calls.
return Err(NssaError::InvalidProgramBehavior); let expected_pre = state_diff
.get(&account_id)
.cloned()
.unwrap_or_else(|| state.get_account_by_id(&account_id));
if pre.account != expected_pre {
return Err(NssaError::InvalidProgramBehavior);
}
// Check that authorization flags are consistent with the provided ones
if pre.is_authorized && !signer_account_ids.contains(&account_id) {
return Err(NssaError::InvalidProgramBehavior);
}
} }
// Verify execution corresponds to a well-behaved program. // Verify execution corresponds to a well-behaved program.
@ -126,7 +148,7 @@ impl PublicTransaction {
if !validate_execution( if !validate_execution(
&program_output.pre_states, &program_output.pre_states,
&program_output.post_states, &program_output.post_states,
program_id, chained_call.program_id,
) { ) {
return Err(NssaError::InvalidProgramBehavior); return Err(NssaError::InvalidProgramBehavior);
} }
@ -134,7 +156,7 @@ impl PublicTransaction {
// The invoked program claims the accounts with default program id. // The invoked program claims the accounts with default program id.
for post in program_output.post_states.iter_mut() { for post in program_output.post_states.iter_mut() {
if post.program_owner == DEFAULT_PROGRAM_ID { if post.program_owner == DEFAULT_PROGRAM_ID {
post.program_owner = program_id; post.program_owner = chained_call.program_id;
} }
} }
@ -147,37 +169,11 @@ impl PublicTransaction {
state_diff.insert(pre.account_id, post.clone()); state_diff.insert(pre.account_id, post.clone());
} }
if let Some(next_chained_call) = program_output.chained_call { for new_call in program_output.chained_calls.into_iter().rev() {
program_id = next_chained_call.program_id; chained_calls.push_front(new_call);
instruction_data = next_chained_call.instruction_data; }
// Build post states with metadata for next call chain_calls_counter += 1;
let mut post_states_with_metadata = Vec::new();
for (pre, post) in program_output
.pre_states
.iter()
.zip(program_output.post_states)
{
let mut post_with_metadata = pre.clone();
post_with_metadata.account = post.clone();
post_states_with_metadata.push(post_with_metadata);
}
input_pre_states = next_chained_call
.account_indices
.iter()
.map(|&i| {
post_states_with_metadata
.get(i)
.ok_or_else(|| {
NssaError::InvalidInput("Invalid account indices".into())
})
.cloned()
})
.collect::<Result<Vec<_>, NssaError>>()?;
} else {
break;
};
} }
Ok(state_diff) Ok(state_diff)

View File

@ -13,6 +13,8 @@ use crate::{
public_transaction::PublicTransaction, public_transaction::PublicTransaction,
}; };
pub const MAX_NUMBER_CHAINED_CALLS: usize = 10;
pub(crate) struct CommitmentSet { pub(crate) struct CommitmentSet {
merkle_tree: MerkleTree, merkle_tree: MerkleTree,
commitments: HashMap<Commitment, usize>, commitments: HashMap<Commitment, usize>,
@ -261,6 +263,7 @@ pub mod tests {
program::Program, program::Program,
public_transaction, public_transaction,
signature::PrivateKey, signature::PrivateKey,
state::MAX_NUMBER_CHAINED_CALLS,
}; };
fn transfer_transaction( fn transfer_transaction(
@ -2084,30 +2087,30 @@ pub mod tests {
} }
#[test] #[test]
fn test_chained_call() { fn test_chained_call_succeeds() {
let program = Program::chain_caller(); let program = Program::chain_caller();
let key = PrivateKey::try_new([1; 32]).unwrap(); let key = PrivateKey::try_new([1; 32]).unwrap();
let account_id = AccountId::from(&PublicKey::new_from_private_key(&key)); let from = AccountId::from(&PublicKey::new_from_private_key(&key));
let to = AccountId::new([2; 32]);
let initial_balance = 100; let initial_balance = 100;
let initial_data = [(account_id, initial_balance)]; let initial_data = [(from, initial_balance), (to, 0)];
let mut state = let mut state =
V02State::new_with_genesis_accounts(&initial_data, &[]).with_test_programs(); V02State::new_with_genesis_accounts(&initial_data, &[]).with_test_programs();
let from = account_id;
let from_key = key; let from_key = key;
let to = AccountId::new([2; 32]); let amount: u128 = 0;
let amount: u128 = 37; let instruction: (u128, ProgramId, u32) =
let instruction: (u128, ProgramId) = (amount, Program::authenticated_transfer_program().id(), 2);
(amount, Program::authenticated_transfer_program().id());
let expected_to_post = Account { let expected_to_post = Account {
program_owner: Program::chain_caller().id(), program_owner: Program::authenticated_transfer_program().id(),
balance: amount, balance: amount * 2, // The `chain_caller` chains the program twice
..Account::default() ..Account::default()
}; };
let message = public_transaction::Message::try_new( let message = public_transaction::Message::try_new(
program.id(), program.id(),
vec![to, from], // The chain_caller program permutes the account order in the call vec![to, from], // The chain_caller program permutes the account order in the chain
// call
vec![0], vec![0],
instruction, instruction,
) )
@ -2119,7 +2122,44 @@ pub mod tests {
let from_post = state.get_account_by_id(&from); let from_post = state.get_account_by_id(&from);
let to_post = state.get_account_by_id(&to); let to_post = state.get_account_by_id(&to);
assert_eq!(from_post.balance, initial_balance - amount); // The `chain_caller` program calls the program twice
assert_eq!(from_post.balance, initial_balance - 2 * amount);
assert_eq!(to_post, expected_to_post); assert_eq!(to_post, expected_to_post);
} }
#[test]
fn test_execution_fails_if_chained_calls_exceeds_depth() {
let program = Program::chain_caller();
let key = PrivateKey::try_new([1; 32]).unwrap();
let from = AccountId::from(&PublicKey::new_from_private_key(&key));
let to = AccountId::new([2; 32]);
let initial_balance = 100;
let initial_data = [(from, initial_balance), (to, 0)];
let mut state =
V02State::new_with_genesis_accounts(&initial_data, &[]).with_test_programs();
let from_key = key;
let amount: u128 = 0;
let instruction: (u128, ProgramId, u32) = (
amount,
Program::authenticated_transfer_program().id(),
MAX_NUMBER_CHAINED_CALLS as u32 + 1,
);
let message = public_transaction::Message::try_new(
program.id(),
vec![to, from], // The chain_caller program permutes the account order in the chain
// call
vec![0],
instruction,
)
.unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[&from_key]);
let tx = PublicTransaction::new(message, witness_set);
let result = state.transition_from_public_transaction(&tx);
assert!(matches!(
result,
Err(NssaError::MaxChainedCallsDepthExceeded)
));
}
} }

View File

@ -3,14 +3,14 @@ use nssa_core::program::{
}; };
use risc0_zkvm::serde::to_vec; use risc0_zkvm::serde::to_vec;
type Instruction = (u128, ProgramId); type Instruction = (u128, ProgramId, u32);
/// A program that calls another program. /// A program that calls another program `num_chain_calls` times.
/// It permutes the order of the input accounts on the subsequent call /// It permutes the order of the input accounts on the subsequent call
fn main() { fn main() {
let ProgramInput { let ProgramInput {
pre_states, pre_states,
instruction: (balance, program_id), instruction: (balance, program_id, num_chain_calls),
} = read_nssa_inputs::<Instruction>(); } = read_nssa_inputs::<Instruction>();
let [sender_pre, receiver_pre] = match pre_states.try_into() { let [sender_pre, receiver_pre] = match pre_states.try_into() {
@ -20,10 +20,19 @@ fn main() {
let instruction_data = to_vec(&balance).unwrap(); let instruction_data = to_vec(&balance).unwrap();
let chained_call = Some(ChainedCall { let mut chained_call = vec![
ChainedCall {
program_id,
instruction_data: instruction_data.clone(),
pre_states: vec![receiver_pre.clone(), sender_pre.clone()], // <- Account order permutation here
};
num_chain_calls as usize - 1
];
chained_call.push(ChainedCall {
program_id, program_id,
instruction_data, instruction_data,
account_indices: vec![1, 0], // <- Account order permutation here pre_states: vec![receiver_pre.clone(), sender_pre.clone()], // <- Account order permutation here
}); });
write_nssa_outputs_with_chained_call( write_nssa_outputs_with_chained_call(