feat: implement private multi chain calls in circuit

This commit is contained in:
Daniil Polyakov 2025-12-24 22:58:33 +03:00
parent 1d09afd9e0
commit 847bd1a376
6 changed files with 294 additions and 217 deletions

View File

@ -15,9 +15,8 @@ pub type Nonce = u128;
/// Account to be used both in public and private contexts
#[derive(
Clone, Default, Eq, PartialEq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, BorshSerialize, BorshDeserialize,
)]
#[cfg_attr(any(feature = "host", test), derive(Debug))]
pub struct Account {
pub program_owner: ProgramId,
pub balance: u128,
@ -25,8 +24,7 @@ pub struct Account {
pub nonce: Nonce,
}
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
#[cfg_attr(any(feature = "host", test), derive(Debug))]
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct AccountWithMetadata {
pub account: Account,
pub is_authorized: bool,
@ -45,6 +43,7 @@ impl AccountWithMetadata {
}
#[derive(
Debug,
Default,
Copy,
Clone,
@ -56,7 +55,7 @@ impl AccountWithMetadata {
BorshSerialize,
BorshDeserialize,
)]
#[cfg_attr(any(feature = "host", test), derive(Debug, PartialOrd, Ord))]
#[cfg_attr(any(feature = "host", test), derive(PartialOrd, Ord))]
pub struct AccountId {
value: [u8; 32],
}

View File

@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize};
pub const DATA_MAX_LENGTH_IN_BYTES: usize = 100 * 1024; // 100 KiB
#[derive(Default, Clone, PartialEq, Eq, Serialize, BorshSerialize)]
#[cfg_attr(any(feature = "host", test), derive(Debug))]
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, BorshSerialize)]
pub struct Data(Vec<u8>);
impl Data {

View File

@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
use crate::{Commitment, account::AccountId};
#[derive(Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(feature = "host", test), derive(Debug, Clone, Hash))]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(feature = "host", test), derive(Clone, Hash))]
pub struct NullifierPublicKey(pub [u8; 32]);
impl From<&NullifierPublicKey> for AccountId {

View File

@ -108,6 +108,11 @@ impl AccountPostState {
pub fn account_mut(&mut self) -> &mut Account {
&mut self.account
}
/// Consumes the post state and returns the underlying account
pub fn into_account(self) -> Account {
self.account
}
}
#[derive(Serialize, Deserialize, Clone)]

View File

@ -119,7 +119,7 @@ impl PublicTransaction {
return Err(NssaError::MaxChainedCallsDepthExceeded);
}
// Check the `program_id` corresponds to a deployed program
// Check that the `program_id` corresponds to a deployed program
let Some(program) = state.programs().get(&chained_call.program_id) else {
return Err(NssaError::InvalidInput("Unknown program".into()));
};
@ -136,11 +136,11 @@ impl PublicTransaction {
);
let authorized_pdas =
self.compute_authorized_pdas(&caller_program_id, &chained_call.pda_seeds);
Self::compute_authorized_pdas(&caller_program_id, &chained_call.pda_seeds);
for pre in &program_output.pre_states {
let account_id = pre.account_id;
// Check that the program output pre_states coinicide with the values in the public
// Check that the program output pre_states coincide with the values in the public
// state or with any modifications to those values during the chain of calls.
let expected_pre = state_diff
.get(&account_id)
@ -202,7 +202,6 @@ impl PublicTransaction {
}
fn compute_authorized_pdas(
&self,
caller_program_id: &Option<ProgramId>,
pda_seeds: &[PdaSeed],
) -> HashSet<AccountId> {

View File

@ -1,12 +1,18 @@
use std::collections::HashMap;
use std::{
collections::{HashMap, VecDeque},
convert::Infallible,
};
use nssa_core::{
Commitment, CommitmentSetDigest, DUMMY_COMMITMENT_HASH, EncryptionScheme, Nullifier,
NullifierPublicKey, PrivacyPreservingCircuitInput, PrivacyPreservingCircuitOutput,
account::{Account, AccountId, AccountWithMetadata},
Commitment, CommitmentSetDigest, DUMMY_COMMITMENT_HASH, EncryptionScheme, MembershipProof,
Nullifier, NullifierPublicKey, NullifierSecretKey, PrivacyPreservingCircuitInput,
PrivacyPreservingCircuitOutput, SharedSecretKey,
account::{Account, AccountId, AccountWithMetadata, Nonce},
compute_digest_for_path,
encryption::Ciphertext,
program::{DEFAULT_PROGRAM_ID, MAX_NUMBER_CHAINED_CALLS, validate_execution},
program::{
ChainedCall, DEFAULT_PROGRAM_ID, MAX_NUMBER_CHAINED_CALLS, ProgramId, ProgramOutput,
validate_execution,
},
};
use risc0_zkvm::{guest::env, serde::to_vec};
@ -18,118 +24,172 @@ fn main() {
private_account_keys,
private_account_nsks,
private_account_membership_proofs,
mut program_id,
program_id,
} = env::read();
let mut pre_states: Vec<AccountWithMetadata> = Vec::new();
let mut state_diff: HashMap<AccountId, Account> = HashMap::new();
let execution_state = ExecutionState::derive_from_outputs(program_id, program_outputs);
let num_calls = program_outputs.len();
if num_calls > MAX_NUMBER_CHAINED_CALLS {
panic!("Max chained calls depth is exceeded");
let output = compute_circuit_output(
execution_state,
&visibility_mask,
&private_account_nonces,
&private_account_keys,
&private_account_nsks,
&private_account_membership_proofs,
);
env::commit(&output);
}
/// World state before and after program execution.
struct ExecutionState {
pre_states: Vec<AccountWithMetadata>,
post_states: HashMap<AccountId, Account>,
}
impl ExecutionState {
/// Validate program outputs and derive the overall execution state.
pub fn derive_from_outputs(program_id: ProgramId, program_outputs: Vec<ProgramOutput>) -> Self {
let Some(first_output) = program_outputs.first() else {
panic!("Program outputs is empty")
};
let initial_call = ChainedCall {
program_id,
instruction_data: first_output.instruction_data.clone(),
pre_states: first_output.pre_states.clone(),
pda_seeds: Vec::new(),
};
let mut chained_calls = VecDeque::from_iter([initial_call]);
let mut execution_state = ExecutionState {
pre_states: Vec::new(),
post_states: HashMap::new(),
};
let mut last_program_id = program_id;
let mut program_outputs_iter = program_outputs.into_iter();
let mut chain_calls_counter = 0;
while let Some(chained_call) = chained_calls.pop_front() {
assert!(
chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS,
"Max chained calls depth is exceeded"
);
let Some(program_output) = program_outputs_iter.next() else {
panic!("Insufficient program outputs for chained calls");
};
// Check that instruction data in chained call is the instruction data in program output
assert_eq!(
chained_call.instruction_data, program_output.instruction_data,
"Mismatched instruction data between chained call and program output"
);
// Check that `program_output` is consistent with the execution of the corresponding
// program.
let program_output_words =
&to_vec(&program_output).expect("program_output must be serializable");
env::verify(chained_call.program_id, program_output_words).unwrap_or_else(
|_: Infallible| unreachable!("Infallible error is never constructed"),
);
// TODO: Why private execution doesn't care about public account authorization?
// Check that the program is well behaved.
// See the # Programs section for the definition of the `validate_execution` method.
let execution_valid = validate_execution(
&program_output.pre_states,
&program_output.post_states,
chained_call.program_id,
);
assert!(execution_valid, "Bad behaved program");
for next_call in program_output.chained_calls.iter().rev() {
chained_calls.push_front(next_call.clone());
}
execution_state.populate_from_output(chained_call.program_id, program_output);
last_program_id = chained_call.program_id;
chain_calls_counter += 1;
}
assert!(
program_outputs_iter.next().is_none(),
"Inner call without a chained call found",
);
// Claim accounts
for account in execution_state.post_states.values_mut() {
if account.program_owner == DEFAULT_PROGRAM_ID {
account.program_owner = last_program_id;
}
}
execution_state
}
let Some(last_program_call) = program_outputs.last() else {
panic!("Program outputs is empty")
fn populate_from_output(&mut self, program_id: ProgramId, program_output: ProgramOutput) {
for (pre, mut post) in program_output
.pre_states
.into_iter()
.zip(program_output.post_states)
{
let pre_account_id = pre.account_id;
if let Some(account_pre) = self.post_states.get(&pre_account_id) {
assert_eq!(account_pre, &pre.account, "Inconsistent pre state");
} else {
self.pre_states.push(pre);
}
if post.requires_claim() {
// The invoked program can only claim accounts with default program id.
if post.account().program_owner == DEFAULT_PROGRAM_ID {
post.account_mut().program_owner = program_id;
} else {
panic!("Cannot claim an initialized account")
}
}
self.post_states.insert(pre_account_id, post.into_account());
}
}
/// Get an iterator over pre and post states of each account involved in the execution.
pub fn into_states_iter(
mut self,
) -> impl ExactSizeIterator<Item = (AccountWithMetadata, Account)> {
self.pre_states.into_iter().map(move |pre| {
let post = self
.post_states
.remove(&pre.account_id)
.expect("Account from pre states should exist in state diff");
(pre, post)
})
}
}
fn compute_circuit_output(
execution_state: ExecutionState,
visibility_mask: &[u8],
private_account_nonces: &[Nonce],
private_account_keys: &[(NullifierPublicKey, SharedSecretKey)],
private_account_nsks: &[NullifierSecretKey],
private_account_membership_proofs: &[Option<MembershipProof>],
) -> PrivacyPreservingCircuitOutput {
let mut output = PrivacyPreservingCircuitOutput {
public_pre_states: Vec::new(),
public_post_states: Vec::new(),
ciphertexts: Vec::new(),
new_commitments: Vec::new(),
new_nullifiers: Vec::new(),
};
if !last_program_call.chained_calls.is_empty() {
panic!("Call stack is incomplete");
}
for window in program_outputs.windows(2) {
let caller = &window[0];
let callee = &window[1];
if caller.chained_calls.len() > 1 {
panic!("Privacy Multi-chained calls are not supported yet");
}
// TODO: Modify when multi-chain calls are supported in the circuit
let Some(caller_chained_call) = &caller.chained_calls.first() else {
panic!("Expected chained call");
};
// Check that instruction data in caller is the instruction data in callee
if caller_chained_call.instruction_data != callee.instruction_data {
panic!("Invalid instruction data");
}
// Check that account pre_states in caller are the ones in calle
if caller_chained_call.pre_states != callee.pre_states {
panic!("Invalid pre states");
}
}
for (i, program_output) in program_outputs.iter().enumerate() {
let mut program_output = program_output.clone();
// Check that `program_output` is consistent with the execution of the corresponding
// program.
let program_output_words =
&to_vec(&program_output).expect("program_output must be serializable");
env::verify(program_id, program_output_words)
.expect("program output must match the program's execution");
// Check that the program is well behaved.
// See the # Programs section for the definition of the `validate_execution` method.
if !validate_execution(
&program_output.pre_states,
&program_output.post_states,
program_id,
) {
panic!("Bad behaved program");
}
// The invoked program claims the accounts with default program id.
for post in program_output
.post_states
.iter_mut()
.filter(|post| post.requires_claim())
{
// The invoked program can only claim accounts with default program id.
if post.account().program_owner == DEFAULT_PROGRAM_ID {
post.account_mut().program_owner = program_id;
} else {
panic!("Cannot claim an initialized account")
}
}
for (pre, post) in program_output
.pre_states
.iter()
.zip(&program_output.post_states)
{
if let Some(account_pre) = state_diff.get(&pre.account_id) {
if account_pre != &pre.account {
panic!("Invalid input");
}
} else {
pre_states.push(pre.clone());
}
state_diff.insert(pre.account_id, post.account().clone());
}
// TODO: Modify when multi-chain calls are supported in the circuit
if let Some(next_chained_call) = &program_output.chained_calls.first() {
program_id = next_chained_call.program_id;
} else if i != program_outputs.len() - 1 {
panic!("Inner call without a chained call found")
};
}
let n_accounts = pre_states.len();
if visibility_mask.len() != n_accounts {
panic!("Invalid visibility mask length");
}
// These lists will be the public outputs of this circuit
// and will be populated next.
let mut public_pre_states: Vec<AccountWithMetadata> = Vec::new();
let mut public_post_states: Vec<Account> = Vec::new();
let mut ciphertexts: Vec<Ciphertext> = Vec::new();
let mut new_commitments: Vec<Commitment> = Vec::new();
let mut new_nullifiers: Vec<(Nullifier, CommitmentSetDigest)> = Vec::new();
let states_iter = execution_state.into_states_iter();
assert_eq!(
visibility_mask.len(),
states_iter.len(),
"Invalid visibility mask length"
);
let mut private_nonces_iter = private_account_nonces.iter();
let mut private_keys_iter = private_account_keys.iter();
@ -137,141 +197,156 @@ fn main() {
let mut private_membership_proofs_iter = private_account_membership_proofs.iter();
let mut output_index = 0;
for i in 0..n_accounts {
match visibility_mask[i] {
for (visibility_mask, (pre_state, post_state)) in
visibility_mask.iter().copied().zip(states_iter)
{
match visibility_mask {
0 => {
// Public account
public_pre_states.push(pre_states[i].clone());
let mut post = state_diff.get(&pre_states[i].account_id).unwrap().clone();
if post.program_owner == DEFAULT_PROGRAM_ID {
// Claim account
post.program_owner = program_id;
}
public_post_states.push(post);
output.public_pre_states.push(pre_state);
output.public_post_states.push(post_state);
}
1 | 2 => {
let new_nonce = private_nonces_iter.next().expect("Missing private nonce");
let (npk, shared_secret) = private_keys_iter.next().expect("Missing keys");
let Some((npk, shared_secret)) = private_keys_iter.next() else {
panic!("Missing private account key");
};
if AccountId::from(npk) != pre_states[i].account_id {
panic!("AccountId mismatch");
}
assert_eq!(
AccountId::from(npk),
pre_state.account_id,
"AccountId mismatch"
);
if visibility_mask[i] == 1 {
let new_nullifier = if visibility_mask == 1 {
// Private account with authentication
let nsk = private_nsks_iter.next().expect("Missing nsk");
let Some(nsk) = private_nsks_iter.next() else {
panic!("Missing private account nullifier secret key");
};
// Verify the nullifier public key
let expected_npk = NullifierPublicKey::from(nsk);
if &expected_npk != npk {
panic!("Nullifier public key mismatch");
}
assert_eq!(
npk,
&NullifierPublicKey::from(nsk),
"Nullifier public key mismatch"
);
// Check pre_state authorization
if !pre_states[i].is_authorized {
panic!("Pre-state not authorized");
}
assert!(
pre_state.is_authorized,
"Pre-state not authorized for authenticated private account"
);
let membership_proof_opt = private_membership_proofs_iter
.next()
.expect("Missing membership proof");
let (nullifier, set_digest) = membership_proof_opt
.as_ref()
.map(|membership_proof| {
// Compute commitment set digest associated with provided auth path
let commitment_pre = Commitment::new(npk, &pre_states[i].account);
let set_digest =
compute_digest_for_path(&commitment_pre, membership_proof);
let Some(membership_proof_opt) = private_membership_proofs_iter.next() else {
panic!("Missing membership proof");
};
// Compute update nullifier
let nullifier = Nullifier::for_account_update(&commitment_pre, nsk);
(nullifier, set_digest)
})
.unwrap_or_else(|| {
if pre_states[i].account != Account::default() {
panic!("Found new private account with non default values.");
}
// Compute initialization nullifier
let nullifier = Nullifier::for_account_initialization(npk);
(nullifier, DUMMY_COMMITMENT_HASH)
});
new_nullifiers.push((nullifier, set_digest));
compute_nullifier_and_set_digest(
membership_proof_opt.as_ref(),
&pre_state.account,
npk,
nsk,
)
} else {
// Private account without authentication
if pre_states[i].account != Account::default() {
panic!("Found new private account with non default values.");
}
if pre_states[i].is_authorized {
panic!("Found new private account marked as authorized.");
}
assert_eq!(
pre_state.account,
Account::default(),
"Found new private account with non default values",
);
assert!(
!pre_state.is_authorized,
"Found new private account marked as authorized."
);
let Some(membership_proof_opt) = private_membership_proofs_iter.next() else {
panic!("Missing membership proof");
};
let membership_proof_opt = private_membership_proofs_iter
.next()
.expect("Missing membership proof");
assert!(
membership_proof_opt.is_none(),
"Membership proof must be None for unauthorized accounts"
);
let nullifier = Nullifier::for_account_initialization(npk);
new_nullifiers.push((nullifier, DUMMY_COMMITMENT_HASH));
}
(nullifier, DUMMY_COMMITMENT_HASH)
};
output.new_nullifiers.push(new_nullifier);
// Update post-state with new nonce
let mut post_with_updated_values =
state_diff.get(&pre_states[i].account_id).unwrap().clone();
post_with_updated_values.nonce = *new_nonce;
if post_with_updated_values.program_owner == DEFAULT_PROGRAM_ID {
// Claim account
post_with_updated_values.program_owner = program_id;
}
let mut post_with_updated_nonce = post_state;
let Some(new_nonce) = private_nonces_iter.next() else {
panic!("Missing private account nonce");
};
post_with_updated_nonce.nonce = *new_nonce;
// Compute commitment
let commitment_post = Commitment::new(npk, &post_with_updated_values);
let commitment_post = Commitment::new(npk, &post_with_updated_nonce);
// Encrypt and push post state
let encrypted_account = EncryptionScheme::encrypt(
&post_with_updated_values,
&post_with_updated_nonce,
shared_secret,
&commitment_post,
output_index,
);
new_commitments.push(commitment_post);
ciphertexts.push(encrypted_account);
output.new_commitments.push(commitment_post);
output.ciphertexts.push(encrypted_account);
output_index += 1;
}
_ => panic!("Invalid visibility mask value"),
}
}
if private_nonces_iter.next().is_some() {
panic!("Too many nonces");
}
assert!(private_nonces_iter.next().is_none(), "Too many nonces");
if private_keys_iter.next().is_some() {
panic!("Too many private account keys");
}
assert!(
private_keys_iter.next().is_none(),
"Too many private account keys"
);
if private_nsks_iter.next().is_some() {
panic!("Too many private account authentication keys");
}
assert!(
private_nsks_iter.next().is_none(),
"Too many private account nullifier secret keys"
);
if private_membership_proofs_iter.next().is_some() {
panic!("Too many private account membership proofs");
}
assert!(
private_membership_proofs_iter.next().is_none(),
"Too many private account membership proofs"
);
let output = PrivacyPreservingCircuitOutput {
public_pre_states,
public_post_states,
ciphertexts,
new_commitments,
new_nullifiers,
};
env::commit(&output);
output
}
fn compute_nullifier_and_set_digest(
membership_proof_opt: Option<&MembershipProof>,
pre_account: &Account,
npk: &NullifierPublicKey,
nsk: &NullifierSecretKey,
) -> (Nullifier, CommitmentSetDigest) {
membership_proof_opt
.as_ref()
.map(|membership_proof| {
// Compute commitment set digest associated with provided auth path
let commitment_pre = Commitment::new(npk, pre_account);
let set_digest = compute_digest_for_path(&commitment_pre, membership_proof);
// Compute update nullifier
let nullifier = Nullifier::for_account_update(&commitment_pre, nsk);
(nullifier, set_digest)
})
.unwrap_or_else(|| {
assert_eq!(
*pre_account,
Account::default(),
"Found new private account with non default values"
);
// Compute initialization nullifier
let nullifier = Nullifier::for_account_initialization(npk);
(nullifier, DUMMY_COMMITMENT_HASH)
})
}