logos-execution-zone/nssa/src/validated_state_diff.rs
2026-05-19 13:43:19 -04:00

601 lines
22 KiB
Rust

use std::{
collections::{HashMap, HashSet, VecDeque},
hash::Hash,
};
use log::debug;
use nssa_core::{
BlockId, Commitment, Nullifier, PrivacyPreservingCircuitOutput, Timestamp,
account::{Account, AccountId, AccountWithMetadata},
program::{
ChainedCall, Claim, DEFAULT_PROGRAM_ID, ProgramId, compute_public_authorized_pdas,
validate_execution,
},
};
use crate::{
V03State, ensure,
error::{InvalidProgramBehaviorError, NssaError},
privacy_preserving_transaction::{
PrivacyPreservingTransaction, circuit::Proof, message::Message,
},
program::Program,
program_deployment_transaction::ProgramDeploymentTransaction,
public_transaction::PublicTransaction,
state::MAX_NUMBER_CHAINED_CALLS,
};
/// The raw set of state changes produced by a transaction.
///
/// All fields are public, so a bare `StateDiff` carries no guarantee of having been validated;
/// the checked form is `ValidatedStateDiff`.
pub struct StateDiff {
    /// Account ids that signed the transaction (as returned by `tx.signer_account_ids()`).
    /// Empty for program deployments built in this module.
    pub signer_account_ids: Vec<AccountId>,
    /// New values of the public accounts touched by the transaction, keyed by account id.
    pub public_diff: HashMap<AccountId, Account>,
    /// Commitments to add. Only populated for privacy-preserving transactions by the
    /// constructors in this module.
    pub new_commitments: Vec<Commitment>,
    /// Nullifiers to record. Only populated for privacy-preserving transactions by the
    /// constructors in this module.
    pub new_nullifiers: Vec<Nullifier>,
    /// Program to deploy. Only `Some` for program deployment transactions built in this module.
    pub program: Option<Program>,
}
/// The validated output of executing or verifying a transaction, ready to be applied to the state.
///
/// The wrapped `StateDiff` is a private field, so a `ValidatedStateDiff` can only be constructed
/// by the validation constructors in this module (`from_public_transaction`,
/// `from_privacy_preserving_transaction` and `from_program_deployment_transaction`). This
/// guarantees the diff has been checked before any state mutation occurs.
pub struct ValidatedStateDiff(StateDiff);
impl ValidatedStateDiff {
    /// Validates a public transaction and executes its program call chain against `state`.
    ///
    /// The top-level call and every chained call it (transitively) requests are executed in
    /// depth-first order. Each call sees the public state as modified by the calls before it.
    /// `state` itself is only read, never mutated.
    ///
    /// # Errors
    ///
    /// - [`NssaError::InvalidInput`] for any of:
    ///   - duplicate account ids;
    ///   - a nonce/signature count mismatch;
    ///   - an invalid signature;
    ///   - a signer nonce that differs from the one in `state`;
    ///   - a call to a program that is not deployed.
    /// - [`NssaError::MaxChainedCallsDepthExceeded`] when more than `MAX_NUMBER_CHAINED_CALLS`
    ///   chained calls are requested.
    /// - [`NssaError::OutOfValidityWindow`] when an executed program's block or timestamp validity
    ///   window excludes `block_id` or `timestamp`.
    /// - An [`InvalidProgramBehaviorError`] (converted into [`NssaError`]) when a program's output
    ///   disagrees with the state, with its caller, or with the account-claiming rules.
    /// - Any error returned while executing a program.
    pub fn from_public_transaction(
        tx: &PublicTransaction,
        state: &V03State,
        block_id: BlockId,
        timestamp: Timestamp,
    ) -> Result<Self, NssaError> {
        let message = tx.message();
        let witness_set = tx.witness_set();

        // All account_ids must be different.
        ensure!(
            n_unique(&message.account_ids) == message.account_ids.len(),
            NssaError::InvalidInput("Duplicate account_ids found in message".into())
        );

        // Exactly one nonce must be provided per signature.
        ensure!(
            message.nonces.len() == witness_set.signatures_and_public_keys.len(),
            NssaError::InvalidInput(
                "Mismatch between number of nonces and signatures/public keys".into(),
            )
        );

        // Every signature must be valid for this message.
        ensure!(
            witness_set.is_valid_for(message),
            NssaError::InvalidInput("Invalid signature for given message and public key".into())
        );

        let signer_account_ids = tx.signer_account_ids();

        // Each signer's nonce must match its current nonce in the public state.
        for (account_id, nonce) in signer_account_ids.iter().zip(&message.nonces) {
            let current_nonce = state.get_account_by_id(*account_id).nonce;
            ensure!(
                current_nonce == *nonce,
                NssaError::InvalidInput("Nonce mismatch".into())
            );
        }

        // Accounts modified so far by the call chain, keyed by account id.
        let mut state_diff: HashMap<AccountId, Account> = HashMap::new();

        let initial_call = ChainedCall {
            program_id: message.program_id,
            instruction_data: message.instruction_data.clone(),
            pre_states: build_public_pre_states(state, &message.account_ids, &signer_account_ids),
            pda_seeds: vec![],
        };

        #[expect(
            clippy::items_after_statements,
            reason = "More readable to keep it behind the place where it's used"
        )]
        #[derive(Debug)]
        struct CallerData {
            program_id: Option<ProgramId>,
            authorized_accounts: HashSet<AccountId>,
        }

        // The top-level call has no caller program; its authorized accounts are the signers.
        let initial_caller_data = CallerData {
            program_id: None,
            authorized_accounts: signer_account_ids.iter().copied().collect(),
        };

        // Pending calls. A call's own chained calls are pushed to the front (in reverse), so they
        // run in order before any call queued earlier, i.e. execution is depth-first.
        let mut chained_calls =
            VecDeque::<(ChainedCall, CallerData)>::from_iter([(initial_call, initial_caller_data)]);
        let mut chain_calls_counter = 0;

        while let Some((chained_call, caller_data)) = chained_calls.pop_front() {
            // The top-level call is counted too, so at most `MAX_NUMBER_CHAINED_CALLS + 1`
            // program executions happen per transaction.
            ensure!(
                chain_calls_counter <= MAX_NUMBER_CHAINED_CALLS,
                NssaError::MaxChainedCallsDepthExceeded
            );

            // The callee must be a deployed program.
            let Some(program) = state.programs().get(&chained_call.program_id) else {
                return Err(NssaError::InvalidInput("Unknown program".into()));
            };

            debug!(
                "Program {:?} pre_states: {:?}, instruction_data: {:?}",
                chained_call.program_id, chained_call.pre_states, chained_call.instruction_data
            );

            let mut program_output = program.execute(
                caller_data.program_id,
                &chained_call.pre_states,
                &chained_call.instruction_data,
            )?;

            debug!(
                "Program {:?} output: {:?}",
                chained_call.program_id, program_output
            );

            // An account is authorized for this call if the caller held its authorization, or
            // if it is one of the PDAs the caller authorized through `pda_seeds`.
            let authorized_pdas =
                compute_public_authorized_pdas(caller_data.program_id, &chained_call.pda_seeds);
            let is_authorized = |account_id: &AccountId| {
                authorized_pdas.contains(account_id)
                    || caller_data.authorized_accounts.contains(account_id)
            };

            for pre in &program_output.pre_states {
                let account_id = pre.account_id;

                // Output pre-states must match the public state, including any modifications
                // made by earlier calls in this chain.
                let expected_pre = state_diff
                    .get(&account_id)
                    .cloned()
                    .unwrap_or_else(|| state.get_account_by_id(account_id));
                ensure!(
                    pre.account == expected_pre,
                    InvalidProgramBehaviorError::InconsistentAccountPreState {
                        account_id,
                        expected: Box::new(expected_pre),
                        actual: Box::new(pre.account.clone())
                    }
                );

                // A pre-state may only be flagged as authorized if it really is.
                ensure!(
                    !pre.is_authorized || is_authorized(&account_id),
                    InvalidProgramBehaviorError::InvalidAccountAuthorization { account_id }
                );
            }

            // The output must report the program that was actually invoked...
            ensure!(
                program_output.self_program_id == chained_call.program_id,
                InvalidProgramBehaviorError::MismatchedProgramId {
                    expected: chained_call.program_id,
                    actual: program_output.self_program_id
                }
            );

            // ...and the program that actually called it.
            ensure!(
                program_output.caller_program_id == caller_data.program_id,
                InvalidProgramBehaviorError::MismatchedCallerProgramId {
                    expected: caller_data.program_id,
                    actual: program_output.caller_program_id,
                }
            );

            // Verify execution corresponds to a well-behaved program.
            // See the # Programs section for the definition of the `validate_execution` method.
            validate_execution(
                &program_output.pre_states,
                &program_output.post_states,
                chained_call.program_id,
            )
            .map_err(InvalidProgramBehaviorError::ExecutionValidationFailed)?;

            // Every program in the chain may restrict when the transaction is valid.
            ensure!(
                program_output.block_validity_window.is_valid_for(block_id)
                    && program_output
                        .timestamp_validity_window
                        .is_valid_for(timestamp),
                NssaError::OutOfValidityWindow
            );

            // Apply claims: a post-state may ask for its account to become owned by the
            // invoked program.
            for (pre, post) in program_output
                .pre_states
                .iter()
                .zip(program_output.post_states.iter_mut())
            {
                let Some(claim) = post.required_claim() else {
                    continue;
                };
                let account_id = pre.account_id;

                // The invoked program can only claim accounts with default program id.
                ensure!(
                    post.account().program_owner == DEFAULT_PROGRAM_ID,
                    InvalidProgramBehaviorError::ClaimedNonDefaultAccount { account_id }
                );

                match claim {
                    Claim::Authorized => {
                        // The program can only claim accounts that were authorized by the signer.
                        ensure!(
                            pre.is_authorized,
                            InvalidProgramBehaviorError::ClaimedUnauthorizedAccount { account_id }
                        );
                    }
                    Claim::Pda(seed) => {
                        // The program can only claim accounts that correspond to the PDAs it is
                        // authorized to claim. The public-execution path only sees public
                        // accounts, so the public-PDA derivation is the correct formula here.
                        let pda = AccountId::for_public_pda(&chained_call.program_id, &seed);
                        ensure!(
                            account_id == pda,
                            InvalidProgramBehaviorError::MismatchedPdaClaim {
                                expected: pda,
                                actual: account_id
                            }
                        );
                    }
                }

                post.account_mut().program_owner = chained_call.program_id;
            }

            // Record the (possibly claimed) post-states so later calls see them.
            for (pre, post) in program_output
                .pre_states
                .iter()
                .zip(program_output.post_states.iter())
            {
                state_diff.insert(pre.account_id, post.account().clone());
            }

            // Authorization is passed on only from this call's *checked output* pre-states,
            // never from the pre-states a caller injected into the call's input. This prevents
            // a program from forging authorization for its callees.
            let authorized_accounts: HashSet<_> = program_output
                .pre_states
                .iter()
                .filter(|pre| pre.is_authorized)
                .map(|pre| pre.account_id)
                .collect();

            for new_call in program_output.chained_calls.into_iter().rev() {
                chained_calls.push_front((
                    new_call,
                    CallerData {
                        program_id: Some(chained_call.program_id),
                        authorized_accounts: authorized_accounts.clone(),
                    },
                ));
            }

            chain_calls_counter = chain_calls_counter
                .checked_add(1)
                .expect("we check the max depth at the beginning of the loop");
        }

        // Every account that was owned by `DEFAULT_PROGRAM_ID` in the public state and was
        // modified by the call chain must have been claimed by some program.
        for (account_id, post) in &state_diff {
            let pre = state.get_account_by_id(*account_id);
            if pre.program_owner != DEFAULT_PROGRAM_ID || pre == *post {
                continue;
            }
            ensure!(
                post.program_owner != DEFAULT_PROGRAM_ID,
                InvalidProgramBehaviorError::DefaultAccountModifiedWithoutClaim {
                    account_id: *account_id
                }
            );
        }

        Ok(Self(StateDiff {
            signer_account_ids,
            public_diff: state_diff,
            new_commitments: vec![],
            new_nullifiers: vec![],
            program: None,
        }))
    }

    /// Validates a privacy-preserving transaction and extracts its state diff.
    ///
    /// No program is executed here. Instead, the transaction's proof is verified against the
    /// message and the current public pre-states of the message's public accounts.
    ///
    /// # Errors
    ///
    /// - [`NssaError::InvalidInput`] for any of:
    ///   - a message with neither commitments nor nullifiers;
    ///   - duplicate public account ids, nullifiers or commitments;
    ///   - a nonce/signature count mismatch;
    ///   - an invalid signature;
    ///   - a signer nonce that differs from the one in `state`.
    /// - [`NssaError::OutOfValidityWindow`] when `block_id` or `timestamp` falls outside the
    ///   message's validity windows.
    /// - [`NssaError::InvalidPrivacyPreservingProof`] when the proof does not verify.
    /// - Any error from `state.check_commitments_are_new` or `state.check_nullifiers_are_valid`.
    pub fn from_privacy_preserving_transaction(
        tx: &PrivacyPreservingTransaction,
        state: &V03State,
        block_id: BlockId,
        timestamp: Timestamp,
    ) -> Result<Self, NssaError> {
        let message = &tx.message;
        let witness_set = &tx.witness_set;

        // 1. Commitments or nullifiers are non empty
        ensure!(
            !message.new_commitments.is_empty() || !message.new_nullifiers.is_empty(),
            NssaError::InvalidInput(
                "Empty commitments and empty nullifiers found in message".into(),
            )
        );

        // 2. Check there are no duplicate account_ids in the public_account_ids list.
        ensure!(
            n_unique(&message.public_account_ids) == message.public_account_ids.len(),
            NssaError::InvalidInput("Duplicate account_ids found in message".into())
        );

        // Check there are no duplicate nullifiers. Each entry of `new_nullifiers` pairs a
        // nullifier with extra data, so uniqueness must be checked on the nullifiers themselves:
        // comparing whole pairs would let the same nullifier appear twice with different
        // companion data.
        let new_nullifiers: Vec<Nullifier> = message
            .new_nullifiers
            .iter()
            .map(|(nullifier, _)| *nullifier)
            .collect();
        ensure!(
            n_unique(&new_nullifiers) == new_nullifiers.len(),
            NssaError::InvalidInput("Duplicate nullifiers found in message".into())
        );

        // Check there are no duplicate commitments in the new_commitments list
        ensure!(
            n_unique(&message.new_commitments) == message.new_commitments.len(),
            NssaError::InvalidInput("Duplicate commitments found in message".into())
        );

        // 3. Nonce checks and valid signatures.
        // Exactly one nonce must be provided per signature.
        ensure!(
            message.nonces.len() == witness_set.signatures_and_public_keys.len(),
            NssaError::InvalidInput(
                "Mismatch between number of nonces and signatures/public keys".into(),
            )
        );

        // Every signature must be valid for this message.
        ensure!(
            witness_set.signatures_are_valid_for(message),
            NssaError::InvalidInput("Invalid signature for given message and public key".into())
        );

        let signer_account_ids = tx.signer_account_ids();

        // Each signer's nonce must match its current nonce in the public state.
        for (account_id, nonce) in signer_account_ids.iter().zip(&message.nonces) {
            let current_nonce = state.get_account_by_id(*account_id).nonce;
            ensure!(
                current_nonce == *nonce,
                NssaError::InvalidInput("Nonce mismatch".into())
            );
        }

        // Verify validity window
        ensure!(
            message.block_validity_window.is_valid_for(block_id)
                && message.timestamp_validity_window.is_valid_for(timestamp),
            NssaError::OutOfValidityWindow
        );

        // 4. Proof verification against the current public pre-states.
        let public_pre_states =
            build_public_pre_states(state, &message.public_account_ids, &signer_account_ids);
        check_privacy_preserving_circuit_proof_is_valid(
            &witness_set.proof,
            &public_pre_states,
            message,
        )?;

        // 5. Commitment freshness
        state.check_commitments_are_new(&message.new_commitments)?;

        // 6. Nullifier uniqueness
        state.check_nullifiers_are_valid(&message.new_nullifiers)?;

        let public_diff = message
            .public_account_ids
            .iter()
            .copied()
            .zip(message.public_post_states.clone())
            .collect();

        Ok(Self(StateDiff {
            signer_account_ids,
            public_diff,
            new_commitments: message.new_commitments.clone(),
            new_nullifiers,
            program: None,
        }))
    }

    /// Validates a program deployment transaction.
    ///
    /// # Errors
    ///
    /// - Any error returned by `Program::new` for the submitted bytecode.
    /// - [`NssaError::ProgramAlreadyExists`] if a program with the same id is already deployed
    ///   in `state`.
    pub fn from_program_deployment_transaction(
        tx: &ProgramDeploymentTransaction,
        state: &V03State,
    ) -> Result<Self, NssaError> {
        // TODO: remove clone
        let program = Program::new(tx.message.bytecode.clone())?;
        ensure!(
            !state.programs().contains_key(&program.id()),
            NssaError::ProgramAlreadyExists
        );

        Ok(Self(StateDiff {
            signer_account_ids: vec![],
            public_diff: HashMap::new(),
            new_commitments: vec![],
            new_nullifiers: vec![],
            program: Some(program),
        }))
    }

    /// Returns the public account changes produced by this transaction.
    ///
    /// Used by callers (e.g. the sequencer) to inspect the diff before committing it, for example
    /// to enforce that system accounts are not modified by user transactions. The map is
    /// returned as an owned clone.
    #[must_use]
    pub fn public_diff(&self) -> HashMap<AccountId, Account> {
        self.0.public_diff.clone()
    }

    /// Unwraps the validated diff so it can be applied within the crate.
    pub(crate) fn into_state_diff(self) -> StateDiff {
        self.0
    }
}

/// Builds the pre-states of `account_ids` from `state`.
///
/// An account is marked as authorized exactly when it appears in `signer_account_ids`.
fn build_public_pre_states(
    state: &V03State,
    account_ids: &[AccountId],
    signer_account_ids: &[AccountId],
) -> Vec<AccountWithMetadata> {
    account_ids
        .iter()
        .map(|account_id| {
            AccountWithMetadata::new(
                state.get_account_by_id(*account_id),
                signer_account_ids.contains(account_id),
                *account_id,
            )
        })
        .collect()
}
fn check_privacy_preserving_circuit_proof_is_valid(
proof: &Proof,
public_pre_states: &[AccountWithMetadata],
message: &Message,
) -> Result<(), NssaError> {
let output = PrivacyPreservingCircuitOutput {
public_pre_states: public_pre_states.to_vec(),
public_post_states: message.public_post_states.clone(),
ciphertexts: message
.encrypted_private_post_states
.iter()
.cloned()
.map(|value| value.ciphertext)
.collect(),
new_commitments: message.new_commitments.clone(),
new_nullifiers: message.new_nullifiers.clone(),
block_validity_window: message.block_validity_window,
timestamp_validity_window: message.timestamp_validity_window,
};
proof
.is_valid_for(&output)
.then_some(())
.ok_or(NssaError::InvalidPrivacyPreservingProof)
}
/// Returns how many distinct elements `data` contains.
///
/// Callers compare the result with `data.len()` to detect duplicates.
fn n_unique<T: Eq + Hash>(data: &[T]) -> usize {
    let mut seen: HashSet<&T> = HashSet::new();
    // `insert` returns `true` only the first time a value is seen.
    data.iter().filter(|item| seen.insert(*item)).count()
}
#[cfg(test)]
mod tests {
    use nssa_core::account::{AccountId, Nonce};
    use crate::{
        PrivateKey, PublicKey, V03State,
        program::Program,
        public_transaction::{Message, WitnessSet},
        validated_state_diff::ValidatedStateDiff,
    };
    /// Regression test for an authorization-injection attack: two malicious programs (an
    /// injector and a launderer) try to move a victim's balance without the victim signing
    /// anything.
    ///
    /// Intended attack flow (the malicious programs' code is not in this file, so the steps
    /// below describe what they are expected to do):
    /// 1. The transaction, signed only by the attacker, calls P1 (`malicious_injector`).
    /// 2. P1 chain-calls P2 (`malicious_launderer`), injecting the victim's pre-state with
    ///    `is_authorized = true` into that chained call's `pre_states`.
    /// 3. P2 presumably outputs empty pre/post states, so the victim never appears in its
    ///    checked output. P2 then chain-calls `authenticated_transfer` with the victim still
    ///    marked as authorized.
    /// 4. A validator that derived a callee's authorized accounts from these forged *input*
    ///    pre-states would accept the victim as authorized and execute the transfer.
    ///
    /// `ValidatedStateDiff::from_public_transaction` derives each callee's authorized accounts
    /// from the caller's *checked output* pre-states instead, so the transaction must be
    /// rejected.
    #[test]
    fn malicious_programs_drain_victim_without_signature() {
        // Instruction for `malicious_injector`, encoded as a tuple of primitives because
        // AccountId/Account cannot round-trip through instruction_data via
        // risc0_zkvm::serde (SerializeDisplay issue).
        type InjectorInstruction = (
            nssa_core::program::ProgramId, // p2_id
            nssa_core::program::ProgramId, // auth_transfer_id
            [u8; 32], // victim_id_raw
            u128, // victim_balance
            u128, // victim_nonce
            nssa_core::program::ProgramId, // victim_program_owner
            [u8; 32], // recipient_id_raw
            u128, // amount
        );
        let attacker_key = PrivateKey::try_new([10; 32]).unwrap();
        let attacker_id = AccountId::from(&PublicKey::new_from_private_key(&attacker_key));
        let victim_key = PrivateKey::try_new([20; 32]).unwrap();
        let victim_id = AccountId::from(&PublicKey::new_from_private_key(&victim_key));
        let recipient_id = AccountId::new([42; 32]);
        let victim_balance = 5_000_u128;
        let mut state = V03State::new_with_genesis_accounts(
            &[
                (attacker_id, 100),
                (victim_id, victim_balance),
                (recipient_id, 0),
            ],
            vec![],
            0,
        );
        state.insert_program(Program::malicious_injector());
        state.insert_program(Program::malicious_launderer());
        // Read victim state from chain, exactly as the attacker would.
        let victim_account = state.get_account_by_id(victim_id);
        // Ask the injector to move the victim's whole balance to the recipient.
        let instruction: InjectorInstruction = (
            Program::malicious_launderer().id(),
            Program::authenticated_transfer_program().id(),
            *victim_id.value(),
            victim_account.balance,
            victim_account.nonce.0,
            victim_account.program_owner,
            *recipient_id.value(),
            victim_balance,
        );
        // Only the attacker signs; the victim provides neither a signature nor a nonce.
        let message = Message::try_new(
            Program::malicious_injector().id(),
            vec![attacker_id],
            vec![Nonce(0)],
            instruction,
        )
        .unwrap();
        let witness_set = WitnessSet::for_message(&message, &[&attacker_key]);
        let tx = crate::PublicTransaction::new(message, witness_set);
        let result = ValidatedStateDiff::from_public_transaction(&tx, &state, 1, 0);
        assert!(
            result.is_err(),
            "attack transaction should be rejected by the fixed validator"
        );
        // Confirm the victim's balance is untouched.
        // NOTE(review): `from_public_transaction` only borrows `state`, so these balances cannot
        // change here. The assertions only guard against a future refactor that applies the
        // diff in place.
        let victim_balance_after = state.get_account_by_id(victim_id).balance;
        let recipient_balance_after = state.get_account_by_id(recipient_id).balance;
        assert_eq!(
            victim_balance_after, victim_balance,
            "victim balance should be unchanged"
        );
        assert_eq!(
            recipient_balance_after, 0,
            "recipient should receive nothing"
        );
    }
}