refactor: move private PDA npk into proven ChainedCall and Claim

This commit is contained in:
Moudy 2026-04-16 16:53:54 +02:00
parent e08c8f93b4
commit bda21fb5c5
12 changed files with 147 additions and 106 deletions

View File

@ -5,7 +5,7 @@ use crate::{
NullifierSecretKey, SharedSecretKey,
account::{Account, AccountWithMetadata},
encryption::Ciphertext,
program::{BlockValidityWindow, PdaSeed, ProgramId, ProgramOutput, TimestampValidityWindow},
program::{BlockValidityWindow, ProgramId, ProgramOutput, TimestampValidityWindow},
};
#[derive(Serialize, Deserialize)]
@ -17,6 +17,7 @@ pub struct PrivacyPreservingCircuitInput {
/// - `0` - public account
/// - `1` - private account with authentication
/// - `2` - private account without authentication
/// - `3` - private PDA account
pub visibility_mask: Vec<u8>,
/// Public keys of private accounts.
pub private_account_keys: Vec<(NullifierPublicKey, SharedSecretKey)>,
@ -26,12 +27,6 @@ pub struct PrivacyPreservingCircuitInput {
pub private_account_membership_proofs: Vec<Option<MembershipProof>>,
/// Program ID.
pub program_id: ProgramId,
/// Private PDA info for mask-3 accounts.
/// Unlike the other `private_account_*` fields which are parallel arrays indexed by private
/// account position, this is a separate lookup table. The circuit matches entries by
/// (`program_id`, `seed`) against the chained calls' `pda_seeds` to resolve private PDA
/// authorization.
pub private_pda_info: Vec<(ProgramId, PdaSeed, NullifierPublicKey)>,
}
#[derive(Serialize, Deserialize)]

View File

@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
use crate::{Commitment, account::AccountId};
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(feature = "host", test), derive(Clone, Hash))]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(feature = "host", test), derive(Hash))]
pub struct NullifierPublicKey(pub [u8; 32]);
impl From<&NullifierPublicKey> for AccountId {

View File

@ -65,7 +65,14 @@ pub struct ChainedCall {
pub pre_states: Vec<AccountWithMetadata>,
/// The instruction data to pass.
pub instruction_data: InstructionData,
/// Public PDA seeds authorized for the callee. Each derives an `AccountId` via
/// `AccountId::from((&caller_program_id, seed))`.
pub pda_seeds: Vec<PdaSeed>,
/// Private PDA `(seed, npk)` pairs authorized for the callee. Each derives an `AccountId`
/// via `private_pda_account_id(&caller_program_id, seed, npk)`. The npk binds the
/// authorization to a specific group of controllers and is part of the caller program's
/// Risc0-proven output, so the outer circuit can trust it.
pub private_pda_seeds: Vec<(PdaSeed, NullifierPublicKey)>,
}
impl ChainedCall {
@ -81,6 +88,7 @@ impl ChainedCall {
instruction_data: risc0_zkvm::serde::to_vec(instruction)
.expect("Serialization to Vec<u32> should not fail"),
pda_seeds: Vec::new(),
private_pda_seeds: Vec::new(),
}
}
@ -89,6 +97,15 @@ impl ChainedCall {
self.pda_seeds = pda_seeds;
self
}
#[must_use]
pub fn with_private_pda_seeds(
mut self,
private_pda_seeds: Vec<(PdaSeed, NullifierPublicKey)>,
) -> Self {
self.private_pda_seeds = private_pda_seeds;
self
}
}
/// Represents the final state of an `Account` after a program execution.
@ -114,8 +131,16 @@ pub enum Claim {
/// This will give no error if program had authorization in pre state and may be useful
/// if program decides to give up authorization for a chained call.
Authorized,
/// The program requests ownership of the account through a PDA.
/// The program requests ownership of the account through a public PDA. The `AccountId` is
/// `AccountId::from((&program_id, &seed))`.
Pda(PdaSeed),
/// The program requests ownership of the account through a private PDA. The `AccountId` is
/// `private_pda_account_id(&program_id, &seed, &npk)`. The npk is part of the program's
/// Risc0-proven output, so the outer circuit can trust it.
PrivatePda {
seed: PdaSeed,
npk: NullifierPublicKey,
},
}
impl AccountPostState {
@ -436,29 +461,27 @@ pub fn private_pda_account_id(
)
}
/// Computes the set of PDA `AccountId`s the callee is authorized to mutate.
///
/// `pda_seeds` produces public PDAs. `private_pda_seeds` produces private PDAs whose derivation
/// includes the caller-supplied npk. All seeds and npks must come from the caller's Risc0-proven
/// [`ChainedCall`], so the outer circuit can trust them.
#[must_use]
pub fn compute_authorized_pdas(
caller_program_id: Option<ProgramId>,
pda_seeds: &[PdaSeed],
private_pda_info: &[(ProgramId, PdaSeed, NullifierPublicKey)],
private_pda_seeds: &[(PdaSeed, NullifierPublicKey)],
) -> HashSet<AccountId> {
caller_program_id
.map(|caller_program_id| {
pda_seeds
.iter()
.map(|pda_seed| {
if let Some((_, _, npk)) = private_pda_info
.iter()
.find(|(pid, s, _)| *pid == caller_program_id && s == pda_seed)
{
private_pda_account_id(&caller_program_id, pda_seed, npk)
} else {
AccountId::from((&caller_program_id, pda_seed))
}
})
.collect()
})
.unwrap_or_default()
let Some(caller) = caller_program_id else {
return HashSet::new();
};
let public = pda_seeds
.iter()
.map(|seed| AccountId::from((&caller, seed)));
let private = private_pda_seeds
.iter()
.map(|(seed, npk)| private_pda_account_id(&caller, seed, npk));
public.chain(private).collect()
}
/// Reads the NSSA inputs from the guest environment.
@ -824,59 +847,52 @@ mod tests {
assert_ne!(private_pda_id, standard_private_id);
}
// ---- compute_authorized_pdas with private_pda_info tests ----
// ---- compute_authorized_pdas tests ----
/// With no private PDA info, `compute_authorized_pdas` returns public PDA addresses
/// (backward compatible with the existing behavior).
/// With no private PDA seeds, `compute_authorized_pdas` returns public PDA addresses only.
#[test]
fn compute_authorized_pdas_empty_private_info_returns_public_ids() {
fn compute_authorized_pdas_public_only() {
let caller: ProgramId = [1; 8];
let seed = PdaSeed::new([2; 32]);
let result = compute_authorized_pdas(Some(caller), &[seed], &[]);
let expected = AccountId::from((&caller, &seed));
assert!(result.contains(&expected));
assert_eq!(result.len(), 1);
}
/// When a `pda_seed` matches a `private_pda_info` entry, the result uses the private PDA
/// formula (with `npk`) instead of the public formula.
/// Private PDA seeds produce private PDA `AccountId`s via the `npk`-inclusive derivation.
#[test]
fn compute_authorized_pdas_matching_entry_returns_private_id() {
fn compute_authorized_pdas_private_only() {
let caller: ProgramId = [1; 8];
let seed = PdaSeed::new([2; 32]);
let npk = NullifierPublicKey([3; 32]);
let info = vec![(caller, seed, npk.clone())];
let result = compute_authorized_pdas(Some(caller), &[seed], &info);
let result = compute_authorized_pdas(Some(caller), &[], &[(seed, npk)]);
let expected = private_pda_account_id(&caller, &seed, &npk);
assert!(result.contains(&expected));
// Should NOT contain the public PDA
let public_id = AccountId::from((&caller, &seed));
assert!(!result.contains(&public_id));
assert_eq!(result.len(), 1);
}
/// When a `pda_seed` does NOT match any `private_pda_info` entry, the result uses the
/// standard public PDA formula (no `npk`).
/// Public and private seeds can coexist in a single chained call; both are authorized.
#[test]
fn compute_authorized_pdas_non_matching_entry_returns_public_id() {
fn compute_authorized_pdas_public_and_private() {
let caller: ProgramId = [1; 8];
let seed_a = PdaSeed::new([2; 32]);
let seed_b = PdaSeed::new([9; 32]);
let pub_seed = PdaSeed::new([2; 32]);
let priv_seed = PdaSeed::new([4; 32]);
let npk = NullifierPublicKey([3; 32]);
// Info is for seed_b, but we authorize seed_a
let info = vec![(caller, seed_b, npk)];
let result = compute_authorized_pdas(Some(caller), &[seed_a], &info);
let expected = AccountId::from((&caller, &seed_a));
assert!(result.contains(&expected));
let result = compute_authorized_pdas(Some(caller), &[pub_seed], &[(priv_seed, npk)]);
assert!(result.contains(&AccountId::from((&caller, &pub_seed))));
assert!(result.contains(&private_pda_account_id(&caller, &priv_seed, &npk)));
assert_eq!(result.len(), 2);
}
/// With no caller (top-level call), the result is always empty regardless of
/// `private_pda_info`.
/// With no caller (top-level call), the result is always empty.
#[test]
fn compute_authorized_pdas_no_caller_returns_empty() {
let seed = PdaSeed::new([2; 32]);
let npk = NullifierPublicKey([3; 32]);
let caller: ProgramId = [1; 8];
let info = vec![(caller, seed, npk)];
let result = compute_authorized_pdas(None, &[seed], &info);
let result = compute_authorized_pdas(None, &[seed], &[(seed, npk)]);
assert!(result.is_empty());
}
}

View File

@ -85,6 +85,7 @@ pub fn execute_and_prove(
instruction_data,
pre_states,
pda_seeds: vec![],
private_pda_seeds: vec![],
};
let mut chained_calls = VecDeque::from_iter([(initial_call, initial_program, None)]);
@ -131,7 +132,6 @@ pub fn execute_and_prove(
private_account_nsks,
private_account_membership_proofs,
program_id: program_with_dependencies.program.id(),
private_pda_info: vec![],
};
env_builder.write(&circuit_input).unwrap();

View File

@ -8,7 +8,8 @@ use nssa_core::{
BlockId, Commitment, Nullifier, PrivacyPreservingCircuitOutput, Timestamp,
account::{Account, AccountId, AccountWithMetadata},
program::{
ChainedCall, Claim, DEFAULT_PROGRAM_ID, compute_authorized_pdas, validate_execution,
ChainedCall, Claim, DEFAULT_PROGRAM_ID, compute_authorized_pdas, private_pda_account_id,
validate_execution,
},
};
@ -98,6 +99,7 @@ impl ValidatedStateDiff {
instruction_data: message.instruction_data.clone(),
pre_states: input_pre_states,
pda_seeds: vec![],
private_pda_seeds: vec![],
};
let mut chained_calls = VecDeque::from_iter([(initial_call, None)]);
@ -128,8 +130,11 @@ impl ValidatedStateDiff {
chained_call.program_id, program_output
);
let authorized_pdas =
compute_authorized_pdas(caller_program_id, &chained_call.pda_seeds, &[]);
let authorized_pdas = compute_authorized_pdas(
caller_program_id,
&chained_call.pda_seeds,
&chained_call.private_pda_seeds,
);
let is_authorized = |account_id: &AccountId| {
signer_account_ids.contains(account_id) || authorized_pdas.contains(account_id)
@ -214,6 +219,10 @@ impl ValidatedStateDiff {
let pda = AccountId::from((&chained_call.program_id, &seed));
ensure!(account_id == pda, NssaError::InvalidProgramBehavior);
}
Claim::PrivatePda { seed, npk } => {
let pda = private_pda_account_id(&chained_call.program_id, &seed, &npk);
ensure!(account_id == pda, NssaError::InvalidProgramBehavior);
}
}
post.account_mut().program_owner = chained_call.program_id;

View File

@ -11,8 +11,8 @@ use nssa_core::{
compute_digest_for_path,
program::{
AccountPostState, BlockValidityWindow, ChainedCall, Claim, DEFAULT_PROGRAM_ID,
MAX_NUMBER_CHAINED_CALLS, PdaSeed, ProgramId, ProgramOutput, TimestampValidityWindow,
validate_execution,
MAX_NUMBER_CHAINED_CALLS, ProgramId, ProgramOutput, TimestampValidityWindow,
private_pda_account_id, validate_execution,
},
};
use risc0_zkvm::{guest::env, serde::to_vec};
@ -23,6 +23,12 @@ struct ExecutionState {
post_states: HashMap<AccountId, Account>,
block_validity_window: BlockValidityWindow,
timestamp_validity_window: TimestampValidityWindow,
/// Map from private-PDA `AccountId` to the npk used to derive it, sourced entirely from
/// Risc0-proven `Claim::PrivatePda` in post_states and `private_pda_seeds` in chained
/// calls. `compute_circuit_output` uses this to verify that the npk supplied via
/// `private_account_keys` for a mask-3 account matches the npk attested by some program's
/// proof.
private_pda_bindings: HashMap<AccountId, NullifierPublicKey>,
}
impl ExecutionState {
@ -31,7 +37,6 @@ impl ExecutionState {
visibility_mask: &[u8],
program_id: ProgramId,
program_outputs: Vec<ProgramOutput>,
private_pda_info: &[(ProgramId, PdaSeed, NullifierPublicKey)],
) -> Self {
let block_valid_from = program_outputs
.iter()
@ -67,6 +72,7 @@ impl ExecutionState {
post_states: HashMap::new(),
block_validity_window,
timestamp_validity_window,
private_pda_bindings: HashMap::new(),
};
let Some(first_output) = program_outputs.first() else {
@ -78,6 +84,7 @@ impl ExecutionState {
instruction_data: first_output.instruction_data.clone(),
pre_states: first_output.pre_states.clone(),
pda_seeds: Vec::new(),
private_pda_seeds: Vec::new(),
};
let mut chained_calls = VecDeque::from_iter([(initial_call, None)]);
@ -133,6 +140,27 @@ impl ExecutionState {
);
assert!(execution_valid, "Bad behaved program");
// Collect private-PDA bindings from this program_output's proven data. Each
// `private_pda_seeds` entry in an outgoing chained call attests that the caller
// (this program) authorizes the callee to mutate the PDA derived from
// `(self_program_id, seed, npk)`. Each `Claim::PrivatePda` in this program's
// post_states attests that it claims the PDA derived from the same formula with
// its own program_id.
for next_call in &program_output.chained_calls {
for (seed, npk) in &next_call.private_pda_seeds {
let account_id = private_pda_account_id(&chained_call.program_id, seed, npk);
execution_state
.private_pda_bindings
.insert(account_id, *npk);
}
}
for post in &program_output.post_states {
if let Some(Claim::PrivatePda { seed, npk }) = post.required_claim() {
let account_id = private_pda_account_id(&chained_call.program_id, &seed, &npk);
execution_state.private_pda_bindings.insert(account_id, npk);
}
}
for next_call in program_output.chained_calls.iter().rev() {
chained_calls.push_front((next_call.clone(), Some(chained_call.program_id)));
}
@ -140,13 +168,12 @@ impl ExecutionState {
let authorized_pdas = nssa_core::program::compute_authorized_pdas(
caller_program_id,
&chained_call.pda_seeds,
private_pda_info,
&chained_call.private_pda_seeds,
);
execution_state.validate_and_sync_states(
visibility_mask,
chained_call.program_id,
&authorized_pdas,
private_pda_info,
program_output.pre_states,
program_output.post_states,
);
@ -190,7 +217,6 @@ impl ExecutionState {
visibility_mask: &[u8],
program_id: ProgramId,
authorized_pdas: &HashSet<AccountId>,
private_pda_info: &[(ProgramId, PdaSeed, NullifierPublicKey)],
pre_states: Vec<AccountWithMetadata>,
post_states: Vec<AccountPostState>,
) {
@ -275,27 +301,30 @@ impl ExecutionState {
"Invalid PDA claim for account {pre_account_id} which does not match derived PDA {pda}"
);
}
Claim::PrivatePda { .. } => {
panic!(
"Public account {pre_account_id} cannot be claimed via Claim::PrivatePda"
);
}
}
} else if is_private_pda {
match claim {
Claim::Pda(seed) => {
let (_, _, npk) = private_pda_info
.iter()
.find(|(pid, s, _)| *pid == program_id && s == &seed)
.expect(
"mask-3 PDA claim must have a matching private_pda_info entry",
);
let pda =
nssa_core::program::private_pda_account_id(&program_id, &seed, npk);
Claim::Authorized => {
assert!(
pre_is_authorized,
"Cannot claim unauthorized private PDA {pre_account_id}"
);
}
Claim::PrivatePda { seed, npk } => {
let pda = private_pda_account_id(&program_id, &seed, &npk);
assert_eq!(
pre_account_id, pda,
"Invalid private PDA claim for account {pre_account_id}"
);
}
Claim::Authorized => {
assert!(
pre_is_authorized,
"Cannot claim unauthorized private PDA {pre_account_id}"
Claim::Pda(_) => {
panic!(
"Private PDA {pre_account_id} must be claimed via Claim::PrivatePda, not Claim::Pda"
);
}
}
@ -325,12 +354,11 @@ impl ExecutionState {
}
fn compute_circuit_output(
execution_state: ExecutionState,
mut execution_state: ExecutionState,
visibility_mask: &[u8],
private_account_keys: &[(NullifierPublicKey, SharedSecretKey)],
private_account_nsks: &[NullifierSecretKey],
private_account_membership_proofs: &[Option<MembershipProof>],
private_pda_info: &[(ProgramId, PdaSeed, NullifierPublicKey)],
) -> PrivacyPreservingCircuitOutput {
let mut output = PrivacyPreservingCircuitOutput {
public_pre_states: Vec::new(),
@ -341,6 +369,7 @@ fn compute_circuit_output(
block_validity_window: execution_state.block_validity_window,
timestamp_validity_window: execution_state.timestamp_validity_window,
};
let private_pda_bindings = std::mem::take(&mut execution_state.private_pda_bindings);
let states_iter = execution_state.into_states_iter();
assert_eq!(
@ -461,20 +490,20 @@ fn compute_circuit_output(
.unwrap_or_else(|| panic!("Too many private accounts, output index overflow"));
}
3 => {
// Private PDA account
// Private PDA account. The npk supplied via private_account_keys must match the
// npk attested by some program's Risc0-proven output (either a `Claim::PrivatePda`
// in post_states or a `private_pda_seeds` entry in a chained call). The bindings
// map is built entirely from proven data in `derive_from_outputs`.
let Some((npk, shared_secret)) = private_keys_iter.next() else {
panic!("Missing private account key");
};
// Verify AccountId against private PDA formula
let (pda_program_id, pda_seed, _) = private_pda_info
.iter()
.find(|(_, _, info_npk)| info_npk == npk)
.expect("mask-3 account must have a matching private_pda_info entry");
let attested_npk = private_pda_bindings.get(&pre_state.account_id).expect(
"mask-3 account must be attested by a proven Claim::PrivatePda or ChainedCall.private_pda_seeds entry",
);
assert_eq!(
nssa_core::program::private_pda_account_id(pda_program_id, pda_seed, npk),
pre_state.account_id,
"Private PDA AccountId mismatch"
npk, attested_npk,
"Private PDA npk does not match proven attestation for {}",
pre_state.account_id
);
let (new_nullifier, new_nonce) = if pre_state.is_authorized {
@ -600,25 +629,10 @@ fn main() {
private_account_nsks,
private_account_membership_proofs,
program_id,
private_pda_info,
} = env::read();
// Validate no duplicate (program_id, seed) pairs in private_pda_info
for (i, (pid_a, seed_a, _)) in private_pda_info.iter().enumerate() {
assert!(
!private_pda_info[..i]
.iter()
.any(|(pid_b, seed_b, _)| pid_a == pid_b && seed_a == seed_b),
"Duplicate (program_id, seed) in private_pda_info"
);
}
let execution_state = ExecutionState::derive_from_outputs(
&visibility_mask,
program_id,
program_outputs,
&private_pda_info,
);
let execution_state =
ExecutionState::derive_from_outputs(&visibility_mask, program_id, program_outputs);
let output = compute_circuit_output(
execution_state,
@ -626,7 +640,6 @@ fn main() {
&private_account_keys,
&private_account_nsks,
&private_account_membership_proofs,
&private_pda_info,
);
env::commit(&output);

View File

@ -41,6 +41,7 @@ fn main() {
instruction_data: instruction_data.clone(),
pre_states: vec![running_sender_pre.clone(), running_recipient_pre.clone()], /* <- Account order permutation here */
pda_seeds: pda_seed.iter().copied().collect(),
private_pda_seeds: vec![],
};
chained_calls.push(new_chained_call);

View File

@ -32,6 +32,7 @@ fn main() {
instruction_data: to_vec(&timestamp).unwrap(),
pre_states: pre_states.clone(),
pda_seeds: vec![],
private_pda_seeds: vec![],
};
ProgramOutput::new(

View File

@ -71,6 +71,7 @@ fn main() {
pre_states: vec![receiver_authorized, vault_pre.clone()],
instruction_data: transfer_instruction,
pda_seeds: vec![PdaSeed::new([1_u8; 32])],
private_pda_seeds: vec![],
});
}
// Malicious path (return_funds = false): emit no chained calls.

View File

@ -129,6 +129,7 @@ fn main() {
pre_states: vec![vault_authorized, receiver_pre.clone()],
instruction_data: transfer_instruction,
pda_seeds: vec![PdaSeed::new([0_u8; 32])],
private_pda_seeds: vec![],
};
// Chained call 2: User callback.
@ -139,6 +140,7 @@ fn main() {
pre_states: vec![vault_after_transfer, receiver_after_transfer],
instruction_data: callback_instruction_data,
pda_seeds: vec![],
private_pda_seeds: vec![],
};
// Chained call 3: Self-call to enforce the invariant.
@ -157,6 +159,7 @@ fn main() {
pre_states: vec![vault_after_callback],
instruction_data: invariant_instruction,
pda_seeds: vec![],
private_pda_seeds: vec![],
};
// The initiator itself makes no direct state changes.

View File

@ -39,6 +39,7 @@ fn main() {
instruction_data,
pre_states: vec![authorised_sender, receiver.clone()],
pda_seeds: vec![],
private_pda_seeds: vec![],
};
ProgramOutput::new(

View File

@ -37,6 +37,7 @@ fn main() {
instruction_data: chained_instruction,
pre_states,
pda_seeds: vec![],
private_pda_seeds: vec![],
};
ProgramOutput::new(