diff --git a/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs index 5d13d46..888cb5b 100644 --- a/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/nssa/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -26,7 +26,7 @@ fn main() { let num_calls = program_outputs.len(); if num_calls > MAX_NUMBER_CHAINED_CALLS { - panic!("Max deapth is exceeded"); + panic!("Max depth is exceeded"); } if program_outputs[num_calls - 1].chained_call.is_some() { @@ -44,7 +44,7 @@ fn main() { } } - for program_output in program_outputs { + for (i, program_output) in program_outputs.iter().enumerate() { let mut program_output = program_output.clone(); // Check that `program_output` is consistent with the execution of the corresponding program. @@ -84,8 +84,8 @@ fn main() { if let Some(next_chained_call) = &program_output.chained_call { program_id = next_chained_call.program_id; - } else { - break; + } else if i != program_outputs.len() - 1 { + panic!("Inner call without a chained call found") }; } diff --git a/nssa/src/privacy_preserving_transaction/circuit.rs b/nssa/src/privacy_preserving_transaction/circuit.rs index 96cf583..e0df466 100644 --- a/nssa/src/privacy_preserving_transaction/circuit.rs +++ b/nssa/src/privacy_preserving_transaction/circuit.rs @@ -1,8 +1,10 @@ +use std::collections::HashMap; + use nssa_core::{ MembershipProof, NullifierPublicKey, NullifierSecretKey, PrivacyPreservingCircuitInput, PrivacyPreservingCircuitOutput, SharedSecretKey, account::AccountWithMetadata, - program::{InstructionData, ProgramOutput}, + program::{InstructionData, ProgramId, ProgramOutput}, }; use risc0_zkvm::{ExecutorEnv, InnerReceipt, Receipt, default_prover}; @@ -24,12 +26,16 @@ pub fn execute_and_prove( private_account_keys: &[(NullifierPublicKey, SharedSecretKey)], private_account_auth: &[(NullifierSecretKey, MembershipProof)], program: &Program, + programs: &HashMap, ) -> Result<(PrivacyPreservingCircuitOutput, Proof), NssaError> { + let mut program = program; + let mut instruction_data = instruction_data.clone(); + let mut pre_states = pre_states.to_vec(); let mut env_builder = ExecutorEnv::builder(); let mut program_outputs = Vec::new(); for _i in 0..MAX_NUMBER_CHAINED_CALLS { - let inner_receipt = execute_and_prove_program(program, pre_states, instruction_data)?; + let inner_receipt = execute_and_prove_program(program, &pre_states, &instruction_data)?; let program_output: ProgramOutput = inner_receipt .journal @@ -42,7 +48,33 @@ pub fn execute_and_prove( // Prove circuit. env_builder.add_assumption(inner_receipt); - if program_output.chained_call.is_none() { + if let Some(next_call) = program_output.chained_call { + // TODO: remove unwrap + program = programs.get(&next_call.program_id).unwrap(); + instruction_data = next_call.instruction_data.clone(); + // Build post states with metadata for next call + let mut post_states_with_metadata = Vec::new(); + for (pre, post) in program_output + .pre_states + .iter() + .zip(program_output.post_states) + { + let mut post_with_metadata = pre.clone(); + post_with_metadata.account = post.clone(); + post_states_with_metadata.push(post_with_metadata); + } + + pre_states = next_call + .account_indices + .iter() + .map(|&i| { + post_states_with_metadata + .get(i) + .ok_or_else(|| NssaError::InvalidInput("Invalid account indices".into())) + .cloned() + }) + .collect::, NssaError>>()?; + } else { break; } } diff --git a/nssa/src/state.rs b/nssa/src/state.rs index be3d0ab..0521aa7 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -2081,7 +2081,7 @@ pub mod tests { } #[test] - fn test_chained_call() { + fn test_public_chained_call() { let program = Program::chain_caller(); let key = PrivateKey::try_new([1; 32]).unwrap(); let address = Address::from(&PublicKey::new_from_private_key(&key)); @@ -2119,4 +2119,64 @@ pub mod tests { assert_eq!(from_post.balance, initial_balance - amount); assert_eq!(to_post, expected_to_post); } + + #[test] + fn test_private_chained_call() { + let program = Program::chain_caller(); + let from_keys = test_private_account_keys_1(); + let to_keys = test_private_account_keys_1(); + let initial_balance = 100; + let from_account = AccountWithMetadata::new( + Account { + program_owner: Program::authenticated_transfer_program().id(), + balance: initial_balance, + ..Account::default() + }, + true, + &from_keys.npk(), + ); + let to_account = AccountWithMetadata::new(Account::default(), true, &from_keys.npk()); + let from_commitment = Commitment::new(&from_keys.npk(), &from_account.account); + let mut state = V02State::new_with_genesis_accounts(&[], &[from_commitment.clone()]) + .with_test_programs(); + // let from = address; + // let from_key = key; + // let to = Address::new([2; 32]); + let amount: u128 = 37; + let instruction: (u128, ProgramId) = + (amount, Program::authenticated_transfer_program().id()); + + let from_esk = [3; 32]; + let from_ss = SharedSecretKey::new(&from_esk, &from_keys.ivk()); + let from_epk = EphemeralPublicKey::from_scalar(from_esk); + + let to_esk = [4; 32]; + let to_ss = SharedSecretKey::new(&to_esk, &to_keys.ivk()); + let to_epk = EphemeralPublicKey::from_scalar(to_esk); + + let (output, proof) = execute_and_prove( + &[from_account, to_account], + &Program::serialize_instruction(instruction).unwrap(), + &[1, 2], + &[0xdeadbeef1, 0xdeadbeef2], + &[(from_keys.npk(), from_ss), (to_keys.npk(), to_ss)], + &[( + from_keys.nsk, + state.get_proof_for_commitment(&from_commitment).unwrap(), + )], + &program, + ) + .unwrap(); + + let message = Message::try_from_circuit_output(vec![], vec![], vec![], output).unwrap(); + let witness_set = WitnessSet::for_message(&message, proof, &[]); + let tx = PrivacyPreservingTransaction::new(message, witness_set); + // + // state.transition_from_public_transaction(&tx).unwrap(); + // + // let from_post = state.get_account_by_address(&from); + // let to_post = state.get_account_by_address(&to); + // assert_eq!(from_post.balance, initial_balance - amount); + // assert_eq!(to_post, expected_to_post); + } }