fix: verify caller_program_id in program output

This commit is contained in:
Moudy 2026-04-07 19:03:06 +02:00
parent 495680e2ea
commit 7d465dded7
17 changed files with 109 additions and 17 deletions

View File

@ -20,7 +20,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction: greeting,
},
@ -54,6 +54,7 @@ fn main() {
// called to commit the output.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
vec![pre_state],
vec![post_state],

View File

@ -20,7 +20,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction: greeting,
},
@ -61,6 +61,7 @@ fn main() {
// called to commit the output.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
vec![pre_state],
vec![post_state],

View File

@ -67,7 +67,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction: (function_id, data),
},
@ -87,5 +87,5 @@ fn main() {
// WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be
// called to commit the output.
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write();
ProgramOutput::new(self_program_id, caller_program_id, instruction_words, pre_states, post_states).write();
}

View File

@ -28,7 +28,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction: (),
},
@ -59,6 +59,7 @@ fn main() {
// called to commit the output.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
vec![pre_state],
vec![post_state],

View File

@ -34,7 +34,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction: (),
},
@ -72,6 +72,7 @@ fn main() {
// called to commit the output.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
vec![pre_state],
vec![post_state],

View File

@ -285,6 +285,9 @@ pub struct InvalidWindow;
pub struct ProgramOutput {
/// The program ID of the program that produced this output.
pub self_program_id: ProgramId,
/// The program ID of the caller that invoked this program via a chained call,
/// or `None` if this is a top-level call.
pub caller_program_id: Option<ProgramId>,
/// The instruction data the program received to produce this output.
pub instruction_data: InstructionData,
/// The account pre states the program received to produce this output.
@ -302,12 +305,14 @@ pub struct ProgramOutput {
impl ProgramOutput {
pub const fn new(
self_program_id: ProgramId,
caller_program_id: Option<ProgramId>,
instruction_data: InstructionData,
pre_states: Vec<AccountWithMetadata>,
post_states: Vec<AccountPostState>,
) -> Self {
Self {
self_program_id,
caller_program_id,
instruction_data,
pre_states,
post_states,
@ -630,7 +635,7 @@ mod tests {
#[test]
fn program_output_try_with_block_validity_window_range() {
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
.try_with_block_validity_window(10_u64..100)
.unwrap();
assert_eq!(output.block_validity_window.start(), Some(10));
@ -639,7 +644,7 @@ mod tests {
#[test]
fn program_output_with_block_validity_window_range_from() {
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
.with_block_validity_window(10_u64..);
assert_eq!(output.block_validity_window.start(), Some(10));
assert_eq!(output.block_validity_window.end(), None);
@ -647,7 +652,7 @@ mod tests {
#[test]
fn program_output_with_block_validity_window_range_to() {
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
.with_block_validity_window(..100_u64);
assert_eq!(output.block_validity_window.start(), None);
assert_eq!(output.block_validity_window.end(), Some(100));
@ -655,7 +660,7 @@ mod tests {
#[test]
fn program_output_try_with_block_validity_window_empty_range_fails() {
let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
.try_with_block_validity_window(5_u64..5);
assert!(result.is_err());
}

View File

@ -345,6 +345,13 @@ mod tests {
Self::new(MALICIOUS_SELF_PROGRAM_ID_ELF.to_vec())
.expect("malicious_self_program_id must be a valid Risc0 program")
}
#[must_use]
pub fn malicious_caller_program_id() -> Self {
use test_program_methods::MALICIOUS_CALLER_PROGRAM_ID_ELF;
Self::new(MALICIOUS_CALLER_PROGRAM_ID_ELF.to_vec())
.expect("malicious_caller_program_id must be a valid Risc0 program")
}
}
#[test]

View File

@ -193,6 +193,12 @@ impl PublicTransaction {
NssaError::InvalidProgramBehavior
);
// Verify that the program output's caller_program_id matches the actual caller.
ensure!(
program_output.caller_program_id == caller_program_id,
NssaError::InvalidProgramBehavior
);
// Verify execution corresponds to a well-behaved program.
// See the # Programs section for the definition of the `validate_execution` method.
ensure!(

View File

@ -385,6 +385,7 @@ pub mod tests {
self.insert_program(Program::flash_swap_initiator());
self.insert_program(Program::flash_swap_callback());
self.insert_program(Program::malicious_self_program_id());
self.insert_program(Program::malicious_caller_program_id());
self
}
@ -3736,4 +3737,25 @@ pub mod tests {
"program with wrong self_program_id in output should be rejected"
);
}
#[test]
fn malicious_caller_program_id_rejected_in_public_execution() {
let program = Program::malicious_caller_program_id();
let acc_id = AccountId::new([99; 32]);
let account = Account::default();
let mut state = V03State::new_with_genesis_accounts(&[], &[]).with_test_programs();
state.force_insert_account(acc_id, account);
let message =
public_transaction::Message::try_new(program.id(), vec![acc_id], vec![], ()).unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
let tx = PublicTransaction::new(message, witness_set);
let result = state.transition_from_public_transaction(&tx, 1, 0);
assert!(
result.is_err(),
"program with spoofed caller_program_id in output should be rejected"
);
}
}

View File

@ -15,7 +15,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction,
},
@ -156,6 +156,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states_clone,
post_states,

View File

@ -5,7 +5,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction,
},
@ -60,6 +60,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states_clone,
post_states,

View File

@ -68,7 +68,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction: balance_to_move,
},
@ -86,5 +86,5 @@ fn main() {
_ => panic!("invalid params"),
};
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write();
ProgramOutput::new(self_program_id, caller_program_id, instruction_words, pre_states, post_states).write();
}

View File

@ -47,7 +47,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction: solution,
},
@ -82,6 +82,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pinata, winner],
vec![

View File

@ -53,7 +53,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction: solution,
},
@ -100,6 +100,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![
pinata_definition,

View File

@ -114,6 +114,15 @@ impl ExecutionState {
"Program output self_program_id does not match chained call program_id"
);
// Verify that the program output's caller_program_id matches the actual caller.
// This prevents a malicious user from privately executing an internal function
// by spoofing caller_program_id (e.g. passing caller_program_id = self_program_id
// to bypass access control checks).
assert_eq!(
program_output.caller_program_id, caller_program_id,
"Program output caller_program_id does not match actual caller"
);
// Check that the program is well behaved.
// See the # Programs section for the definition of the `validate_execution` method.
let execution_valid = validate_execution(

View File

@ -13,7 +13,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _,
caller_program_id,
pre_states,
instruction,
},
@ -85,6 +85,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states_clone,
post_states,

View File

@ -0,0 +1,34 @@
use nssa_core::program::{
AccountPostState, DEFAULT_PROGRAM_ID, ProgramInput, ProgramOutput, read_nssa_inputs,
};
type Instruction = ();
fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _, // ignore the actual caller
pre_states,
instruction: (),
},
instruction_words,
) = read_nssa_inputs::<Instruction>();
let post_states = pre_states
.iter()
.map(|a| AccountPostState::new(a.account.clone()))
.collect();
// Deliberately output wrong caller_program_id.
// A real caller_program_id is None for a top-level call, so we spoof Some(DEFAULT_PROGRAM_ID)
// to simulate a program claiming it was invoked by another program when it was not.
ProgramOutput::new(
self_program_id,
Some(DEFAULT_PROGRAM_ID), // WRONG: should be None for a top-level call
instruction_words,
pre_states,
post_states,
)
.write();
}