diff --git a/nssa/core/src/account.rs b/nssa/core/src/account.rs index 51467af..55ab0de 100644 --- a/nssa/core/src/account.rs +++ b/nssa/core/src/account.rs @@ -15,9 +15,8 @@ pub type Nonce = u128; /// Account to be used both in public and private contexts #[derive( - Clone, Default, Eq, PartialEq, Serialize, Deserialize, BorshSerialize, BorshDeserialize, + Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, BorshSerialize, BorshDeserialize, )] -#[cfg_attr(any(feature = "host", test), derive(Debug))] pub struct Account { pub program_owner: ProgramId, pub balance: u128, @@ -25,8 +24,7 @@ pub struct Account { pub nonce: Nonce, } -#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] -#[cfg_attr(any(feature = "host", test), derive(Debug))] +#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] pub struct AccountWithMetadata { pub account: Account, pub is_authorized: bool, @@ -45,6 +43,7 @@ impl AccountWithMetadata { } #[derive( + Debug, Default, Copy, Clone, @@ -56,7 +55,7 @@ impl AccountWithMetadata { BorshSerialize, BorshDeserialize, )] -#[cfg_attr(any(feature = "host", test), derive(Debug, PartialOrd, Ord))] +#[cfg_attr(any(feature = "host", test), derive(PartialOrd, Ord))] pub struct AccountId { value: [u8; 32], } diff --git a/nssa/core/src/account/data.rs b/nssa/core/src/account/data.rs index 974cb06..396bbe6 100644 --- a/nssa/core/src/account/data.rs +++ b/nssa/core/src/account/data.rs @@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize}; pub const DATA_MAX_LENGTH_IN_BYTES: usize = 100 * 1024; // 100 KiB -#[derive(Default, Clone, PartialEq, Eq, Serialize, BorshSerialize)] -#[cfg_attr(any(feature = "host", test), derive(Debug))] +#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, BorshSerialize)] pub struct Data(Vec); impl Data { diff --git a/nssa/core/src/nullifier.rs b/nssa/core/src/nullifier.rs index ec30700..8d9d59f 100644 --- a/nssa/core/src/nullifier.rs +++ b/nssa/core/src/nullifier.rs @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize}; use crate::{Commitment, account::AccountId}; -#[derive(Serialize, Deserialize, PartialEq, Eq)] -#[cfg_attr(any(feature = "host", test), derive(Debug, Clone, Hash))] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[cfg_attr(any(feature = "host", test), derive(Clone, Hash))] pub struct NullifierPublicKey(pub [u8; 32]); impl From<&NullifierPublicKey> for AccountId { diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index 357a4a5..2d1e03d 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -108,6 +108,11 @@ impl AccountPostState { pub fn account_mut(&mut self) -> &mut Account { &mut self.account } + + /// Consumes the post state and returns the underlying account + pub fn into_account(self) -> Account { + self.account + } } #[derive(Serialize, Deserialize, Clone)] diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index 68437e4..93f947e 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -119,7 +119,7 @@ impl PublicTransaction { return Err(NssaError::MaxChainedCallsDepthExceeded); } - // Check the `program_id` corresponds to a deployed program + // Check that the `program_id` corresponds to a deployed program let Some(program) = state.programs().get(&chained_call.program_id) else { return Err(NssaError::InvalidInput("Unknown program".into())); }; @@ -136,11 +136,11 @@ impl PublicTransaction { ); let authorized_pdas = - self.compute_authorized_pdas(&caller_program_id, &chained_call.pda_seeds); + Self::compute_authorized_pdas(&caller_program_id, &chained_call.pda_seeds); for pre in &program_output.pre_states { let account_id = pre.account_id; - // Check that the program output pre_states coinicide with the values in the public + // Check that the program output pre_states coincide with the values in the public // state or with any modifications to those values during the chain of calls. let expected_pre = state_diff .get(&account_id) @@ -202,7 +202,6 @@ impl PublicTransaction { } fn compute_authorized_pdas( - &self, caller_program_id: &Option, pda_seeds: &[PdaSeed], ) -> HashSet { diff --git a/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/program_methods/guest/src/bin/privacy_preserving_circuit.rs index ffe4b13..2aeb940 100644 --- a/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -1,12 +1,18 @@ -use std::collections::HashMap; +use std::{ + collections::{HashMap, VecDeque}, + convert::Infallible, +}; use nssa_core::{ - Commitment, CommitmentSetDigest, DUMMY_COMMITMENT_HASH, EncryptionScheme, Nullifier, - NullifierPublicKey, PrivacyPreservingCircuitInput, PrivacyPreservingCircuitOutput, - account::{Account, AccountId, AccountWithMetadata}, + Commitment, CommitmentSetDigest, DUMMY_COMMITMENT_HASH, EncryptionScheme, MembershipProof, + Nullifier, NullifierPublicKey, NullifierSecretKey, PrivacyPreservingCircuitInput, + PrivacyPreservingCircuitOutput, SharedSecretKey, + account::{Account, AccountId, AccountWithMetadata, Nonce}, compute_digest_for_path, - encryption::Ciphertext, - program::{DEFAULT_PROGRAM_ID, MAX_NUMBER_CHAINED_CALLS, validate_execution}, + program::{ + ChainedCall, DEFAULT_PROGRAM_ID, MAX_NUMBER_CHAINED_CALLS, ProgramId, ProgramOutput, + validate_execution, + }, }; use risc0_zkvm::{guest::env, serde::to_vec}; @@ -18,118 +24,172 @@ fn main() { private_account_keys, private_account_nsks, private_account_membership_proofs, - mut program_id, + program_id, } = env::read(); - let mut pre_states: Vec = Vec::new(); - let mut state_diff: HashMap = HashMap::new(); + let execution_state = ExecutionState::derive_from_outputs(program_id, program_outputs); - let num_calls = program_outputs.len(); - if num_calls > MAX_NUMBER_CHAINED_CALLS { - panic!("Max chained calls depth is exceeded"); + let output = compute_circuit_output( + execution_state, + &visibility_mask, + &private_account_nonces, + &private_account_keys, + &private_account_nsks, + &private_account_membership_proofs, + ); + + env::commit(&output); +} + +/// World state before and after program execution. +struct ExecutionState { + pre_states: Vec, + post_states: HashMap, +} + +impl ExecutionState { + /// Validate program outputs and derive the overall execution state. + pub fn derive_from_outputs(program_id: ProgramId, program_outputs: Vec) -> Self { + let Some(first_output) = program_outputs.first() else { + panic!("Program outputs is empty") + }; + + let initial_call = ChainedCall { + program_id, + instruction_data: first_output.instruction_data.clone(), + pre_states: first_output.pre_states.clone(), + pda_seeds: Vec::new(), + }; + let mut chained_calls = VecDeque::from_iter([initial_call]); + + let mut execution_state = ExecutionState { + pre_states: Vec::new(), + post_states: HashMap::new(), + }; + let mut last_program_id = program_id; + let mut program_outputs_iter = program_outputs.into_iter(); + let mut chain_calls_counter = 0; + while let Some(chained_call) = chained_calls.pop_front() { + assert!( + chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS, + "Max chained calls depth is exceeded" + ); + + let Some(program_output) = program_outputs_iter.next() else { + panic!("Insufficient program outputs for chained calls"); + }; + + // Check that instruction data in chained call is the instruction data in program output + assert_eq!( + chained_call.instruction_data, program_output.instruction_data, + "Mismatched instruction data between chained call and program output" + ); + + // Check that `program_output` is consistent with the execution of the corresponding + // program. + let program_output_words = + &to_vec(&program_output).expect("program_output must be serializable"); + env::verify(chained_call.program_id, program_output_words).unwrap_or_else( + |_: Infallible| unreachable!("Infallible error is never constructed"), + ); + + // TODO: Why private execution doesn't care about public account authorization? + + // Check that the program is well behaved. + // See the # Programs section for the definition of the `validate_execution` method. + let execution_valid = validate_execution( + &program_output.pre_states, + &program_output.post_states, + chained_call.program_id, + ); + assert!(execution_valid, "Bad behaved program"); + + for next_call in program_output.chained_calls.iter().rev() { + chained_calls.push_front(next_call.clone()); + } + + execution_state.populate_from_output(chained_call.program_id, program_output); + last_program_id = chained_call.program_id; + chain_calls_counter += 1; + } + + assert!( + program_outputs_iter.next().is_none(), + "Inner call without a chained call found", + ); + + // Claim accounts + for account in execution_state.post_states.values_mut() { + if account.program_owner == DEFAULT_PROGRAM_ID { + account.program_owner = last_program_id; + } + } + + execution_state } - let Some(last_program_call) = program_outputs.last() else { - panic!("Program outputs is empty") + fn populate_from_output(&mut self, program_id: ProgramId, program_output: ProgramOutput) { + for (pre, mut post) in program_output + .pre_states + .into_iter() + .zip(program_output.post_states) + { + let pre_account_id = pre.account_id; + if let Some(account_pre) = self.post_states.get(&pre_account_id) { + assert_eq!(account_pre, &pre.account, "Inconsistent pre state"); + } else { + self.pre_states.push(pre); + } + + if post.requires_claim() { + // The invoked program can only claim accounts with default program id. + if post.account().program_owner == DEFAULT_PROGRAM_ID { + post.account_mut().program_owner = program_id; + } else { + panic!("Cannot claim an initialized account") + } + } + + self.post_states.insert(pre_account_id, post.into_account()); + } + } + + /// Get an iterator over pre and post states of each account involved in the execution. + pub fn into_states_iter( + mut self, + ) -> impl ExactSizeIterator { + self.pre_states.into_iter().map(move |pre| { + let post = self + .post_states + .remove(&pre.account_id) + .expect("Account from pre states should exist in state diff"); + (pre, post) + }) + } +} + +fn compute_circuit_output( + execution_state: ExecutionState, + visibility_mask: &[u8], + private_account_nonces: &[Nonce], + private_account_keys: &[(NullifierPublicKey, SharedSecretKey)], + private_account_nsks: &[NullifierSecretKey], + private_account_membership_proofs: &[Option], +) -> PrivacyPreservingCircuitOutput { + let mut output = PrivacyPreservingCircuitOutput { + public_pre_states: Vec::new(), + public_post_states: Vec::new(), + ciphertexts: Vec::new(), + new_commitments: Vec::new(), + new_nullifiers: Vec::new(), }; - if !last_program_call.chained_calls.is_empty() { - panic!("Call stack is incomplete"); - } - - for window in program_outputs.windows(2) { - let caller = &window[0]; - let callee = &window[1]; - - if caller.chained_calls.len() > 1 { - panic!("Privacy Multi-chained calls are not supported yet"); - } - - // TODO: Modify when multi-chain calls are supported in the circuit - let Some(caller_chained_call) = &caller.chained_calls.first() else { - panic!("Expected chained call"); - }; - - // Check that instruction data in caller is the instruction data in callee - if caller_chained_call.instruction_data != callee.instruction_data { - panic!("Invalid instruction data"); - } - - // Check that account pre_states in caller are the ones in calle - if caller_chained_call.pre_states != callee.pre_states { - panic!("Invalid pre states"); - } - } - - for (i, program_output) in program_outputs.iter().enumerate() { - let mut program_output = program_output.clone(); - - // Check that `program_output` is consistent with the execution of the corresponding - // program. - let program_output_words = - &to_vec(&program_output).expect("program_output must be serializable"); - env::verify(program_id, program_output_words) - .expect("program output must match the program's execution"); - - // Check that the program is well behaved. - // See the # Programs section for the definition of the `validate_execution` method. - if !validate_execution( - &program_output.pre_states, - &program_output.post_states, - program_id, - ) { - panic!("Bad behaved program"); - } - - // The invoked program claims the accounts with default program id. - for post in program_output - .post_states - .iter_mut() - .filter(|post| post.requires_claim()) - { - // The invoked program can only claim accounts with default program id. - if post.account().program_owner == DEFAULT_PROGRAM_ID { - post.account_mut().program_owner = program_id; - } else { - panic!("Cannot claim an initialized account") - } - } - - for (pre, post) in program_output - .pre_states - .iter() - .zip(&program_output.post_states) - { - if let Some(account_pre) = state_diff.get(&pre.account_id) { - if account_pre != &pre.account { - panic!("Invalid input"); - } - } else { - pre_states.push(pre.clone()); - } - state_diff.insert(pre.account_id, post.account().clone()); - } - - // TODO: Modify when multi-chain calls are supported in the circuit - if let Some(next_chained_call) = &program_output.chained_calls.first() { - program_id = next_chained_call.program_id; - } else if i != program_outputs.len() - 1 { - panic!("Inner call without a chained call found") - }; - } - - let n_accounts = pre_states.len(); - if visibility_mask.len() != n_accounts { - panic!("Invalid visibility mask length"); - } - - // These lists will be the public outputs of this circuit - // and will be populated next. - let mut public_pre_states: Vec = Vec::new(); - let mut public_post_states: Vec = Vec::new(); - let mut ciphertexts: Vec = Vec::new(); - let mut new_commitments: Vec = Vec::new(); - let mut new_nullifiers: Vec<(Nullifier, CommitmentSetDigest)> = Vec::new(); + let states_iter = execution_state.into_states_iter(); + assert_eq!( + visibility_mask.len(), + states_iter.len(), + "Invalid visibility mask length" + ); let mut private_nonces_iter = private_account_nonces.iter(); let mut private_keys_iter = private_account_keys.iter(); @@ -137,141 +197,156 @@ fn main() { let mut private_membership_proofs_iter = private_account_membership_proofs.iter(); let mut output_index = 0; - for i in 0..n_accounts { - match visibility_mask[i] { + for (visibility_mask, (pre_state, post_state)) in + visibility_mask.iter().copied().zip(states_iter) + { + match visibility_mask { 0 => { // Public account - public_pre_states.push(pre_states[i].clone()); - - let mut post = state_diff.get(&pre_states[i].account_id).unwrap().clone(); - - if post.program_owner == DEFAULT_PROGRAM_ID { - // Claim account - post.program_owner = program_id; - } - public_post_states.push(post); + output.public_pre_states.push(pre_state); + output.public_post_states.push(post_state); } 1 | 2 => { - let new_nonce = private_nonces_iter.next().expect("Missing private nonce"); - let (npk, shared_secret) = private_keys_iter.next().expect("Missing keys"); + let Some((npk, shared_secret)) = private_keys_iter.next() else { + panic!("Missing private account key"); + }; - if AccountId::from(npk) != pre_states[i].account_id { - panic!("AccountId mismatch"); - } + assert_eq!( + AccountId::from(npk), + pre_state.account_id, + "AccountId mismatch" + ); - if visibility_mask[i] == 1 { + let new_nullifier = if visibility_mask == 1 { // Private account with authentication - let nsk = private_nsks_iter.next().expect("Missing nsk"); + + let Some(nsk) = private_nsks_iter.next() else { + panic!("Missing private account nullifier secret key"); + }; // Verify the nullifier public key - let expected_npk = NullifierPublicKey::from(nsk); - if &expected_npk != npk { - panic!("Nullifier public key mismatch"); - } + assert_eq!( + npk, + &NullifierPublicKey::from(nsk), + "Nullifier public key mismatch" + ); // Check pre_state authorization - if !pre_states[i].is_authorized { - panic!("Pre-state not authorized"); - } + assert!( + pre_state.is_authorized, + "Pre-state not authorized for authenticated private account" + ); - let membership_proof_opt = private_membership_proofs_iter - .next() - .expect("Missing membership proof"); - let (nullifier, set_digest) = membership_proof_opt - .as_ref() - .map(|membership_proof| { - // Compute commitment set digest associated with provided auth path - let commitment_pre = Commitment::new(npk, &pre_states[i].account); - let set_digest = - compute_digest_for_path(&commitment_pre, membership_proof); + let Some(membership_proof_opt) = private_membership_proofs_iter.next() else { + panic!("Missing membership proof"); + }; - // Compute update nullifier - let nullifier = Nullifier::for_account_update(&commitment_pre, nsk); - (nullifier, set_digest) - }) - .unwrap_or_else(|| { - if pre_states[i].account != Account::default() { - panic!("Found new private account with non default values."); - } - - // Compute initialization nullifier - let nullifier = Nullifier::for_account_initialization(npk); - (nullifier, DUMMY_COMMITMENT_HASH) - }); - new_nullifiers.push((nullifier, set_digest)); + compute_nullifier_and_set_digest( + membership_proof_opt.as_ref(), + &pre_state.account, + npk, + nsk, + ) } else { // Private account without authentication - if pre_states[i].account != Account::default() { - panic!("Found new private account with non default values."); - } - if pre_states[i].is_authorized { - panic!("Found new private account marked as authorized."); - } + assert_eq!( + pre_state.account, + Account::default(), + "Found new private account with non default values", + ); + + assert!( + !pre_state.is_authorized, + "Found new private account marked as authorized." + ); + + let Some(membership_proof_opt) = private_membership_proofs_iter.next() else { + panic!("Missing membership proof"); + }; - let membership_proof_opt = private_membership_proofs_iter - .next() - .expect("Missing membership proof"); assert!( membership_proof_opt.is_none(), "Membership proof must be None for unauthorized accounts" ); + let nullifier = Nullifier::for_account_initialization(npk); - new_nullifiers.push((nullifier, DUMMY_COMMITMENT_HASH)); - } + (nullifier, DUMMY_COMMITMENT_HASH) + }; + output.new_nullifiers.push(new_nullifier); // Update post-state with new nonce - let mut post_with_updated_values = - state_diff.get(&pre_states[i].account_id).unwrap().clone(); - post_with_updated_values.nonce = *new_nonce; - - if post_with_updated_values.program_owner == DEFAULT_PROGRAM_ID { - // Claim account - post_with_updated_values.program_owner = program_id; - } + let mut post_with_updated_nonce = post_state; + let Some(new_nonce) = private_nonces_iter.next() else { + panic!("Missing private account nonce"); + }; + post_with_updated_nonce.nonce = *new_nonce; // Compute commitment - let commitment_post = Commitment::new(npk, &post_with_updated_values); + let commitment_post = Commitment::new(npk, &post_with_updated_nonce); // Encrypt and push post state let encrypted_account = EncryptionScheme::encrypt( - &post_with_updated_values, + &post_with_updated_nonce, shared_secret, &commitment_post, output_index, ); - new_commitments.push(commitment_post); - ciphertexts.push(encrypted_account); + output.new_commitments.push(commitment_post); + output.ciphertexts.push(encrypted_account); output_index += 1; } _ => panic!("Invalid visibility mask value"), } } - if private_nonces_iter.next().is_some() { - panic!("Too many nonces"); - } + assert!(private_nonces_iter.next().is_none(), "Too many nonces"); - if private_keys_iter.next().is_some() { - panic!("Too many private account keys"); - } + assert!( + private_keys_iter.next().is_none(), + "Too many private account keys" + ); - if private_nsks_iter.next().is_some() { - panic!("Too many private account authentication keys"); - } + assert!( + private_nsks_iter.next().is_none(), + "Too many private account nullifier secret keys" + ); - if private_membership_proofs_iter.next().is_some() { - panic!("Too many private account membership proofs"); - } + assert!( + private_membership_proofs_iter.next().is_none(), + "Too many private account membership proofs" + ); - let output = PrivacyPreservingCircuitOutput { - public_pre_states, - public_post_states, - ciphertexts, - new_commitments, - new_nullifiers, - }; - - env::commit(&output); + output +} + +fn compute_nullifier_and_set_digest( + membership_proof_opt: Option<&MembershipProof>, + pre_account: &Account, + npk: &NullifierPublicKey, + nsk: &NullifierSecretKey, +) -> (Nullifier, CommitmentSetDigest) { + membership_proof_opt + .as_ref() + .map(|membership_proof| { + // Compute commitment set digest associated with provided auth path + let commitment_pre = Commitment::new(npk, pre_account); + let set_digest = compute_digest_for_path(&commitment_pre, membership_proof); + + // Compute update nullifier + let nullifier = Nullifier::for_account_update(&commitment_pre, nsk); + (nullifier, set_digest) + }) + .unwrap_or_else(|| { + assert_eq!( + *pre_account, + Account::default(), + "Found new private account with non default values" + ); + + // Compute initialization nullifier + let nullifier = Nullifier::for_account_initialization(npk); + (nullifier, DUMMY_COMMITMENT_HASH) + }) }