feat: implement private multi chain calls in circuit

This commit is contained in:
Daniil Polyakov 2025-12-24 22:58:33 +03:00
parent 1d09afd9e0
commit 847bd1a376
6 changed files with 294 additions and 217 deletions

View File

@ -15,9 +15,8 @@ pub type Nonce = u128;
/// Account to be used both in public and private contexts /// Account to be used both in public and private contexts
#[derive( #[derive(
Clone, Default, Eq, PartialEq, Serialize, Deserialize, BorshSerialize, BorshDeserialize, Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)] )]
#[cfg_attr(any(feature = "host", test), derive(Debug))]
pub struct Account { pub struct Account {
pub program_owner: ProgramId, pub program_owner: ProgramId,
pub balance: u128, pub balance: u128,
@ -25,8 +24,7 @@ pub struct Account {
pub nonce: Nonce, pub nonce: Nonce,
} }
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
#[cfg_attr(any(feature = "host", test), derive(Debug))]
pub struct AccountWithMetadata { pub struct AccountWithMetadata {
pub account: Account, pub account: Account,
pub is_authorized: bool, pub is_authorized: bool,
@ -45,6 +43,7 @@ impl AccountWithMetadata {
} }
#[derive( #[derive(
Debug,
Default, Default,
Copy, Copy,
Clone, Clone,
@ -56,7 +55,7 @@ impl AccountWithMetadata {
BorshSerialize, BorshSerialize,
BorshDeserialize, BorshDeserialize,
)] )]
#[cfg_attr(any(feature = "host", test), derive(Debug, PartialOrd, Ord))] #[cfg_attr(any(feature = "host", test), derive(PartialOrd, Ord))]
pub struct AccountId { pub struct AccountId {
value: [u8; 32], value: [u8; 32],
} }

View File

@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize};
pub const DATA_MAX_LENGTH_IN_BYTES: usize = 100 * 1024; // 100 KiB pub const DATA_MAX_LENGTH_IN_BYTES: usize = 100 * 1024; // 100 KiB
#[derive(Default, Clone, PartialEq, Eq, Serialize, BorshSerialize)] #[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, BorshSerialize)]
#[cfg_attr(any(feature = "host", test), derive(Debug))]
pub struct Data(Vec<u8>); pub struct Data(Vec<u8>);
impl Data { impl Data {

View File

@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
use crate::{Commitment, account::AccountId}; use crate::{Commitment, account::AccountId};
#[derive(Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(feature = "host", test), derive(Debug, Clone, Hash))] #[cfg_attr(any(feature = "host", test), derive(Clone, Hash))]
pub struct NullifierPublicKey(pub [u8; 32]); pub struct NullifierPublicKey(pub [u8; 32]);
impl From<&NullifierPublicKey> for AccountId { impl From<&NullifierPublicKey> for AccountId {

View File

@ -108,6 +108,11 @@ impl AccountPostState {
pub fn account_mut(&mut self) -> &mut Account { pub fn account_mut(&mut self) -> &mut Account {
&mut self.account &mut self.account
} }
/// Consumes the post state and returns the underlying account
pub fn into_account(self) -> Account {
self.account
}
} }
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]

View File

@ -119,7 +119,7 @@ impl PublicTransaction {
return Err(NssaError::MaxChainedCallsDepthExceeded); return Err(NssaError::MaxChainedCallsDepthExceeded);
} }
// Check the `program_id` corresponds to a deployed program // Check that the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&chained_call.program_id) else { let Some(program) = state.programs().get(&chained_call.program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into())); return Err(NssaError::InvalidInput("Unknown program".into()));
}; };
@ -136,11 +136,11 @@ impl PublicTransaction {
); );
let authorized_pdas = let authorized_pdas =
self.compute_authorized_pdas(&caller_program_id, &chained_call.pda_seeds); Self::compute_authorized_pdas(&caller_program_id, &chained_call.pda_seeds);
for pre in &program_output.pre_states { for pre in &program_output.pre_states {
let account_id = pre.account_id; let account_id = pre.account_id;
// Check that the program output pre_states coinicide with the values in the public // Check that the program output pre_states coincide with the values in the public
// state or with any modifications to those values during the chain of calls. // state or with any modifications to those values during the chain of calls.
let expected_pre = state_diff let expected_pre = state_diff
.get(&account_id) .get(&account_id)
@ -202,7 +202,6 @@ impl PublicTransaction {
} }
fn compute_authorized_pdas( fn compute_authorized_pdas(
&self,
caller_program_id: &Option<ProgramId>, caller_program_id: &Option<ProgramId>,
pda_seeds: &[PdaSeed], pda_seeds: &[PdaSeed],
) -> HashSet<AccountId> { ) -> HashSet<AccountId> {

View File

@ -1,12 +1,18 @@
use std::collections::HashMap; use std::{
collections::{HashMap, VecDeque},
convert::Infallible,
};
use nssa_core::{ use nssa_core::{
Commitment, CommitmentSetDigest, DUMMY_COMMITMENT_HASH, EncryptionScheme, Nullifier, Commitment, CommitmentSetDigest, DUMMY_COMMITMENT_HASH, EncryptionScheme, MembershipProof,
NullifierPublicKey, PrivacyPreservingCircuitInput, PrivacyPreservingCircuitOutput, Nullifier, NullifierPublicKey, NullifierSecretKey, PrivacyPreservingCircuitInput,
account::{Account, AccountId, AccountWithMetadata}, PrivacyPreservingCircuitOutput, SharedSecretKey,
account::{Account, AccountId, AccountWithMetadata, Nonce},
compute_digest_for_path, compute_digest_for_path,
encryption::Ciphertext, program::{
program::{DEFAULT_PROGRAM_ID, MAX_NUMBER_CHAINED_CALLS, validate_execution}, ChainedCall, DEFAULT_PROGRAM_ID, MAX_NUMBER_CHAINED_CALLS, ProgramId, ProgramOutput,
validate_execution,
},
}; };
use risc0_zkvm::{guest::env, serde::to_vec}; use risc0_zkvm::{guest::env, serde::to_vec};
@ -18,118 +24,172 @@ fn main() {
private_account_keys, private_account_keys,
private_account_nsks, private_account_nsks,
private_account_membership_proofs, private_account_membership_proofs,
mut program_id, program_id,
} = env::read(); } = env::read();
let mut pre_states: Vec<AccountWithMetadata> = Vec::new(); let execution_state = ExecutionState::derive_from_outputs(program_id, program_outputs);
let mut state_diff: HashMap<AccountId, Account> = HashMap::new();
let num_calls = program_outputs.len(); let output = compute_circuit_output(
if num_calls > MAX_NUMBER_CHAINED_CALLS { execution_state,
panic!("Max chained calls depth is exceeded"); &visibility_mask,
&private_account_nonces,
&private_account_keys,
&private_account_nsks,
&private_account_membership_proofs,
);
env::commit(&output);
}
/// World state before and after program execution.
struct ExecutionState {
pre_states: Vec<AccountWithMetadata>,
post_states: HashMap<AccountId, Account>,
}
impl ExecutionState {
/// Validate program outputs and derive the overall execution state.
pub fn derive_from_outputs(program_id: ProgramId, program_outputs: Vec<ProgramOutput>) -> Self {
let Some(first_output) = program_outputs.first() else {
panic!("Program outputs is empty")
};
let initial_call = ChainedCall {
program_id,
instruction_data: first_output.instruction_data.clone(),
pre_states: first_output.pre_states.clone(),
pda_seeds: Vec::new(),
};
let mut chained_calls = VecDeque::from_iter([initial_call]);
let mut execution_state = ExecutionState {
pre_states: Vec::new(),
post_states: HashMap::new(),
};
let mut last_program_id = program_id;
let mut program_outputs_iter = program_outputs.into_iter();
let mut chain_calls_counter = 0;
while let Some(chained_call) = chained_calls.pop_front() {
assert!(
chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS,
"Max chained calls depth is exceeded"
);
let Some(program_output) = program_outputs_iter.next() else {
panic!("Insufficient program outputs for chained calls");
};
// Check that instruction data in chained call is the instruction data in program output
assert_eq!(
chained_call.instruction_data, program_output.instruction_data,
"Mismatched instruction data between chained call and program output"
);
// Check that `program_output` is consistent with the execution of the corresponding
// program.
let program_output_words =
&to_vec(&program_output).expect("program_output must be serializable");
env::verify(chained_call.program_id, program_output_words).unwrap_or_else(
|_: Infallible| unreachable!("Infallible error is never constructed"),
);
// TODO: Why private execution doesn't care about public account authorization?
// Check that the program is well behaved.
// See the # Programs section for the definition of the `validate_execution` method.
let execution_valid = validate_execution(
&program_output.pre_states,
&program_output.post_states,
chained_call.program_id,
);
assert!(execution_valid, "Bad behaved program");
for next_call in program_output.chained_calls.iter().rev() {
chained_calls.push_front(next_call.clone());
}
execution_state.populate_from_output(chained_call.program_id, program_output);
last_program_id = chained_call.program_id;
chain_calls_counter += 1;
}
assert!(
program_outputs_iter.next().is_none(),
"Inner call without a chained call found",
);
// Claim accounts
for account in execution_state.post_states.values_mut() {
if account.program_owner == DEFAULT_PROGRAM_ID {
account.program_owner = last_program_id;
}
}
execution_state
} }
let Some(last_program_call) = program_outputs.last() else { fn populate_from_output(&mut self, program_id: ProgramId, program_output: ProgramOutput) {
panic!("Program outputs is empty") for (pre, mut post) in program_output
.pre_states
.into_iter()
.zip(program_output.post_states)
{
let pre_account_id = pre.account_id;
if let Some(account_pre) = self.post_states.get(&pre_account_id) {
assert_eq!(account_pre, &pre.account, "Inconsistent pre state");
} else {
self.pre_states.push(pre);
}
if post.requires_claim() {
// The invoked program can only claim accounts with default program id.
if post.account().program_owner == DEFAULT_PROGRAM_ID {
post.account_mut().program_owner = program_id;
} else {
panic!("Cannot claim an initialized account")
}
}
self.post_states.insert(pre_account_id, post.into_account());
}
}
/// Get an iterator over pre and post states of each account involved in the execution.
pub fn into_states_iter(
mut self,
) -> impl ExactSizeIterator<Item = (AccountWithMetadata, Account)> {
self.pre_states.into_iter().map(move |pre| {
let post = self
.post_states
.remove(&pre.account_id)
.expect("Account from pre states should exist in state diff");
(pre, post)
})
}
}
fn compute_circuit_output(
execution_state: ExecutionState,
visibility_mask: &[u8],
private_account_nonces: &[Nonce],
private_account_keys: &[(NullifierPublicKey, SharedSecretKey)],
private_account_nsks: &[NullifierSecretKey],
private_account_membership_proofs: &[Option<MembershipProof>],
) -> PrivacyPreservingCircuitOutput {
let mut output = PrivacyPreservingCircuitOutput {
public_pre_states: Vec::new(),
public_post_states: Vec::new(),
ciphertexts: Vec::new(),
new_commitments: Vec::new(),
new_nullifiers: Vec::new(),
}; };
if !last_program_call.chained_calls.is_empty() { let states_iter = execution_state.into_states_iter();
panic!("Call stack is incomplete"); assert_eq!(
} visibility_mask.len(),
states_iter.len(),
for window in program_outputs.windows(2) { "Invalid visibility mask length"
let caller = &window[0]; );
let callee = &window[1];
if caller.chained_calls.len() > 1 {
panic!("Privacy Multi-chained calls are not supported yet");
}
// TODO: Modify when multi-chain calls are supported in the circuit
let Some(caller_chained_call) = &caller.chained_calls.first() else {
panic!("Expected chained call");
};
// Check that instruction data in caller is the instruction data in callee
if caller_chained_call.instruction_data != callee.instruction_data {
panic!("Invalid instruction data");
}
// Check that account pre_states in caller are the ones in calle
if caller_chained_call.pre_states != callee.pre_states {
panic!("Invalid pre states");
}
}
for (i, program_output) in program_outputs.iter().enumerate() {
let mut program_output = program_output.clone();
// Check that `program_output` is consistent with the execution of the corresponding
// program.
let program_output_words =
&to_vec(&program_output).expect("program_output must be serializable");
env::verify(program_id, program_output_words)
.expect("program output must match the program's execution");
// Check that the program is well behaved.
// See the # Programs section for the definition of the `validate_execution` method.
if !validate_execution(
&program_output.pre_states,
&program_output.post_states,
program_id,
) {
panic!("Bad behaved program");
}
// The invoked program claims the accounts with default program id.
for post in program_output
.post_states
.iter_mut()
.filter(|post| post.requires_claim())
{
// The invoked program can only claim accounts with default program id.
if post.account().program_owner == DEFAULT_PROGRAM_ID {
post.account_mut().program_owner = program_id;
} else {
panic!("Cannot claim an initialized account")
}
}
for (pre, post) in program_output
.pre_states
.iter()
.zip(&program_output.post_states)
{
if let Some(account_pre) = state_diff.get(&pre.account_id) {
if account_pre != &pre.account {
panic!("Invalid input");
}
} else {
pre_states.push(pre.clone());
}
state_diff.insert(pre.account_id, post.account().clone());
}
// TODO: Modify when multi-chain calls are supported in the circuit
if let Some(next_chained_call) = &program_output.chained_calls.first() {
program_id = next_chained_call.program_id;
} else if i != program_outputs.len() - 1 {
panic!("Inner call without a chained call found")
};
}
let n_accounts = pre_states.len();
if visibility_mask.len() != n_accounts {
panic!("Invalid visibility mask length");
}
// These lists will be the public outputs of this circuit
// and will be populated next.
let mut public_pre_states: Vec<AccountWithMetadata> = Vec::new();
let mut public_post_states: Vec<Account> = Vec::new();
let mut ciphertexts: Vec<Ciphertext> = Vec::new();
let mut new_commitments: Vec<Commitment> = Vec::new();
let mut new_nullifiers: Vec<(Nullifier, CommitmentSetDigest)> = Vec::new();
let mut private_nonces_iter = private_account_nonces.iter(); let mut private_nonces_iter = private_account_nonces.iter();
let mut private_keys_iter = private_account_keys.iter(); let mut private_keys_iter = private_account_keys.iter();
@ -137,141 +197,156 @@ fn main() {
let mut private_membership_proofs_iter = private_account_membership_proofs.iter(); let mut private_membership_proofs_iter = private_account_membership_proofs.iter();
let mut output_index = 0; let mut output_index = 0;
for i in 0..n_accounts { for (visibility_mask, (pre_state, post_state)) in
match visibility_mask[i] { visibility_mask.iter().copied().zip(states_iter)
{
match visibility_mask {
0 => { 0 => {
// Public account // Public account
public_pre_states.push(pre_states[i].clone()); output.public_pre_states.push(pre_state);
output.public_post_states.push(post_state);
let mut post = state_diff.get(&pre_states[i].account_id).unwrap().clone();
if post.program_owner == DEFAULT_PROGRAM_ID {
// Claim account
post.program_owner = program_id;
}
public_post_states.push(post);
} }
1 | 2 => { 1 | 2 => {
let new_nonce = private_nonces_iter.next().expect("Missing private nonce"); let Some((npk, shared_secret)) = private_keys_iter.next() else {
let (npk, shared_secret) = private_keys_iter.next().expect("Missing keys"); panic!("Missing private account key");
};
if AccountId::from(npk) != pre_states[i].account_id { assert_eq!(
panic!("AccountId mismatch"); AccountId::from(npk),
} pre_state.account_id,
"AccountId mismatch"
);
if visibility_mask[i] == 1 { let new_nullifier = if visibility_mask == 1 {
// Private account with authentication // Private account with authentication
let nsk = private_nsks_iter.next().expect("Missing nsk");
let Some(nsk) = private_nsks_iter.next() else {
panic!("Missing private account nullifier secret key");
};
// Verify the nullifier public key // Verify the nullifier public key
let expected_npk = NullifierPublicKey::from(nsk); assert_eq!(
if &expected_npk != npk { npk,
panic!("Nullifier public key mismatch"); &NullifierPublicKey::from(nsk),
} "Nullifier public key mismatch"
);
// Check pre_state authorization // Check pre_state authorization
if !pre_states[i].is_authorized { assert!(
panic!("Pre-state not authorized"); pre_state.is_authorized,
} "Pre-state not authorized for authenticated private account"
);
let membership_proof_opt = private_membership_proofs_iter let Some(membership_proof_opt) = private_membership_proofs_iter.next() else {
.next() panic!("Missing membership proof");
.expect("Missing membership proof"); };
let (nullifier, set_digest) = membership_proof_opt
.as_ref()
.map(|membership_proof| {
// Compute commitment set digest associated with provided auth path
let commitment_pre = Commitment::new(npk, &pre_states[i].account);
let set_digest =
compute_digest_for_path(&commitment_pre, membership_proof);
// Compute update nullifier compute_nullifier_and_set_digest(
let nullifier = Nullifier::for_account_update(&commitment_pre, nsk); membership_proof_opt.as_ref(),
(nullifier, set_digest) &pre_state.account,
}) npk,
.unwrap_or_else(|| { nsk,
if pre_states[i].account != Account::default() { )
panic!("Found new private account with non default values.");
}
// Compute initialization nullifier
let nullifier = Nullifier::for_account_initialization(npk);
(nullifier, DUMMY_COMMITMENT_HASH)
});
new_nullifiers.push((nullifier, set_digest));
} else { } else {
// Private account without authentication // Private account without authentication
if pre_states[i].account != Account::default() {
panic!("Found new private account with non default values.");
}
if pre_states[i].is_authorized { assert_eq!(
panic!("Found new private account marked as authorized."); pre_state.account,
} Account::default(),
"Found new private account with non default values",
);
assert!(
!pre_state.is_authorized,
"Found new private account marked as authorized."
);
let Some(membership_proof_opt) = private_membership_proofs_iter.next() else {
panic!("Missing membership proof");
};
let membership_proof_opt = private_membership_proofs_iter
.next()
.expect("Missing membership proof");
assert!( assert!(
membership_proof_opt.is_none(), membership_proof_opt.is_none(),
"Membership proof must be None for unauthorized accounts" "Membership proof must be None for unauthorized accounts"
); );
let nullifier = Nullifier::for_account_initialization(npk); let nullifier = Nullifier::for_account_initialization(npk);
new_nullifiers.push((nullifier, DUMMY_COMMITMENT_HASH)); (nullifier, DUMMY_COMMITMENT_HASH)
} };
output.new_nullifiers.push(new_nullifier);
// Update post-state with new nonce // Update post-state with new nonce
let mut post_with_updated_values = let mut post_with_updated_nonce = post_state;
state_diff.get(&pre_states[i].account_id).unwrap().clone(); let Some(new_nonce) = private_nonces_iter.next() else {
post_with_updated_values.nonce = *new_nonce; panic!("Missing private account nonce");
};
if post_with_updated_values.program_owner == DEFAULT_PROGRAM_ID { post_with_updated_nonce.nonce = *new_nonce;
// Claim account
post_with_updated_values.program_owner = program_id;
}
// Compute commitment // Compute commitment
let commitment_post = Commitment::new(npk, &post_with_updated_values); let commitment_post = Commitment::new(npk, &post_with_updated_nonce);
// Encrypt and push post state // Encrypt and push post state
let encrypted_account = EncryptionScheme::encrypt( let encrypted_account = EncryptionScheme::encrypt(
&post_with_updated_values, &post_with_updated_nonce,
shared_secret, shared_secret,
&commitment_post, &commitment_post,
output_index, output_index,
); );
new_commitments.push(commitment_post); output.new_commitments.push(commitment_post);
ciphertexts.push(encrypted_account); output.ciphertexts.push(encrypted_account);
output_index += 1; output_index += 1;
} }
_ => panic!("Invalid visibility mask value"), _ => panic!("Invalid visibility mask value"),
} }
} }
if private_nonces_iter.next().is_some() { assert!(private_nonces_iter.next().is_none(), "Too many nonces");
panic!("Too many nonces");
}
if private_keys_iter.next().is_some() { assert!(
panic!("Too many private account keys"); private_keys_iter.next().is_none(),
} "Too many private account keys"
);
if private_nsks_iter.next().is_some() { assert!(
panic!("Too many private account authentication keys"); private_nsks_iter.next().is_none(),
} "Too many private account nullifier secret keys"
);
if private_membership_proofs_iter.next().is_some() { assert!(
panic!("Too many private account membership proofs"); private_membership_proofs_iter.next().is_none(),
} "Too many private account membership proofs"
);
let output = PrivacyPreservingCircuitOutput { output
public_pre_states, }
public_post_states,
ciphertexts, fn compute_nullifier_and_set_digest(
new_commitments, membership_proof_opt: Option<&MembershipProof>,
new_nullifiers, pre_account: &Account,
}; npk: &NullifierPublicKey,
nsk: &NullifierSecretKey,
env::commit(&output); ) -> (Nullifier, CommitmentSetDigest) {
membership_proof_opt
.as_ref()
.map(|membership_proof| {
// Compute commitment set digest associated with provided auth path
let commitment_pre = Commitment::new(npk, pre_account);
let set_digest = compute_digest_for_path(&commitment_pre, membership_proof);
// Compute update nullifier
let nullifier = Nullifier::for_account_update(&commitment_pre, nsk);
(nullifier, set_digest)
})
.unwrap_or_else(|| {
assert_eq!(
*pre_account,
Account::default(),
"Found new private account with non default values"
);
// Compute initialization nullifier
let nullifier = Nullifier::for_account_initialization(npk);
(nullifier, DUMMY_COMMITMENT_HASH)
})
} }