diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world.rs b/examples/program_deployment/methods/guest/src/bin/hello_world.rs index caae793d..3e91db0e 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world.rs @@ -20,7 +20,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction: greeting, }, @@ -54,6 +54,7 @@ fn main() { // called to commit the output. ProgramOutput::new( self_program_id, + caller_program_id, instruction_data, vec![pre_state], vec![post_state], diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs b/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs index 25825bf5..70dfa2ae 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world_with_authorization.rs @@ -20,7 +20,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction: greeting, }, @@ -61,6 +61,7 @@ fn main() { // called to commit the output. ProgramOutput::new( self_program_id, + caller_program_id, instruction_data, vec![pre_state], vec![post_state], diff --git a/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs b/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs index 2adc1ebe..a45a6cd2 100644 --- a/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs +++ b/examples/program_deployment/methods/guest/src/bin/hello_world_with_move_function.rs @@ -67,7 +67,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction: (function_id, data), }, @@ -87,5 +87,5 @@ fn main() { // WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be // called to commit the output. - ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write(); + ProgramOutput::new(self_program_id, caller_program_id, instruction_words, pre_states, post_states).write(); } diff --git a/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs b/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs index 6e8ead22..716e5c29 100644 --- a/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs +++ b/examples/program_deployment/methods/guest/src/bin/simple_tail_call.rs @@ -28,7 +28,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction: (), }, @@ -59,6 +59,7 @@ fn main() { // called to commit the output. ProgramOutput::new( self_program_id, + caller_program_id, instruction_data, vec![pre_state], vec![post_state], diff --git a/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs b/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs index 85023ffd..5ec9aaab 100644 --- a/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs +++ b/examples/program_deployment/methods/guest/src/bin/tail_call_with_pda.rs @@ -34,7 +34,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction: (), }, @@ -72,6 +72,7 @@ fn main() { // called to commit the output. ProgramOutput::new( self_program_id, + caller_program_id, instruction_data, vec![pre_state], vec![post_state], diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index 3d98aed2..a08fb2b4 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -285,6 +285,9 @@ pub struct InvalidWindow; pub struct ProgramOutput { /// The program ID of the program that produced this output. pub self_program_id: ProgramId, + /// The program ID of the caller that invoked this program via a chained call, + /// or `None` if this is a top-level call. + pub caller_program_id: Option, /// The instruction data the program received to produce this output. pub instruction_data: InstructionData, /// The account pre states the program received to produce this output. @@ -302,12 +305,14 @@ pub struct ProgramOutput { impl ProgramOutput { pub const fn new( self_program_id: ProgramId, + caller_program_id: Option, instruction_data: InstructionData, pre_states: Vec, post_states: Vec, ) -> Self { Self { self_program_id, + caller_program_id, instruction_data, pre_states, post_states, @@ -630,7 +635,7 @@ mod tests { #[test] fn program_output_try_with_block_validity_window_range() { - let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![]) .try_with_block_validity_window(10_u64..100) .unwrap(); assert_eq!(output.block_validity_window.start(), Some(10)); @@ -639,7 +644,7 @@ mod tests { #[test] fn program_output_with_block_validity_window_range_from() { - let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![]) .with_block_validity_window(10_u64..); assert_eq!(output.block_validity_window.start(), Some(10)); assert_eq!(output.block_validity_window.end(), None); @@ -647,7 +652,7 @@ mod tests { #[test] fn program_output_with_block_validity_window_range_to() { - let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![]) .with_block_validity_window(..100_u64); assert_eq!(output.block_validity_window.start(), None); assert_eq!(output.block_validity_window.end(), Some(100)); @@ -655,7 +660,7 @@ mod tests { #[test] fn program_output_try_with_block_validity_window_empty_range_fails() { - let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![]) + let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![]) .try_with_block_validity_window(5_u64..5); assert!(result.is_err()); } diff --git a/nssa/src/program.rs b/nssa/src/program.rs index 27badcf9..b3c4b510 100644 --- a/nssa/src/program.rs +++ b/nssa/src/program.rs @@ -345,6 +345,13 @@ mod tests { Self::new(MALICIOUS_SELF_PROGRAM_ID_ELF.to_vec()) .expect("malicious_self_program_id must be a valid Risc0 program") } + + #[must_use] + pub fn malicious_caller_program_id() -> Self { + use test_program_methods::MALICIOUS_CALLER_PROGRAM_ID_ELF; + Self::new(MALICIOUS_CALLER_PROGRAM_ID_ELF.to_vec()) + .expect("malicious_caller_program_id must be a valid Risc0 program") + } } #[test] diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index 23222a55..36d13b65 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -193,6 +193,12 @@ impl PublicTransaction { NssaError::InvalidProgramBehavior ); + // Verify that the program output's caller_program_id matches the actual caller. + ensure!( + program_output.caller_program_id == caller_program_id, + NssaError::InvalidProgramBehavior + ); + // Verify execution corresponds to a well-behaved program. // See the # Programs section for the definition of the `validate_execution` method. ensure!( diff --git a/nssa/src/state.rs b/nssa/src/state.rs index c5cdae4b..e733540f 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -385,6 +385,7 @@ pub mod tests { self.insert_program(Program::flash_swap_initiator()); self.insert_program(Program::flash_swap_callback()); self.insert_program(Program::malicious_self_program_id()); + self.insert_program(Program::malicious_caller_program_id()); self } @@ -3736,4 +3737,25 @@ pub mod tests { "program with wrong self_program_id in output should be rejected" ); } + + #[test] + fn malicious_caller_program_id_rejected_in_public_execution() { + let program = Program::malicious_caller_program_id(); + let acc_id = AccountId::new([99; 32]); + let account = Account::default(); + + let mut state = V03State::new_with_genesis_accounts(&[], &[]).with_test_programs(); + state.force_insert_account(acc_id, account); + + let message = + public_transaction::Message::try_new(program.id(), vec![acc_id], vec![], ()).unwrap(); + let witness_set = public_transaction::WitnessSet::for_message(&message, &[]); + let tx = PublicTransaction::new(message, witness_set); + + let result = state.transition_from_public_transaction(&tx, 1, 0); + assert!( + result.is_err(), + "program with spoofed caller_program_id in output should be rejected" + ); + } } diff --git a/program_methods/guest/src/bin/amm.rs b/program_methods/guest/src/bin/amm.rs index 90f7e06f..bce76c63 100644 --- a/program_methods/guest/src/bin/amm.rs +++ b/program_methods/guest/src/bin/amm.rs @@ -15,7 +15,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction, }, @@ -156,6 +156,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, pre_states_clone, post_states, diff --git a/program_methods/guest/src/bin/associated_token_account.rs b/program_methods/guest/src/bin/associated_token_account.rs index 2dd074d9..9b155d7f 100644 --- a/program_methods/guest/src/bin/associated_token_account.rs +++ b/program_methods/guest/src/bin/associated_token_account.rs @@ -5,7 +5,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction, }, @@ -60,6 +60,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, pre_states_clone, post_states, diff --git a/program_methods/guest/src/bin/authenticated_transfer.rs b/program_methods/guest/src/bin/authenticated_transfer.rs index 302c1620..3ddbd840 100644 --- a/program_methods/guest/src/bin/authenticated_transfer.rs +++ b/program_methods/guest/src/bin/authenticated_transfer.rs @@ -68,7 +68,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction: balance_to_move, }, @@ -86,5 +86,5 @@ fn main() { _ => panic!("invalid params"), }; - ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write(); + ProgramOutput::new(self_program_id, caller_program_id, instruction_words, pre_states, post_states).write(); } diff --git a/program_methods/guest/src/bin/pinata.rs b/program_methods/guest/src/bin/pinata.rs index 1c1e2e94..dcc76397 100644 --- a/program_methods/guest/src/bin/pinata.rs +++ b/program_methods/guest/src/bin/pinata.rs @@ -47,7 +47,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction: solution, }, @@ -82,6 +82,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![pinata, winner], vec![ diff --git a/program_methods/guest/src/bin/pinata_token.rs b/program_methods/guest/src/bin/pinata_token.rs index 2e09489c..1f7ad9da 100644 --- a/program_methods/guest/src/bin/pinata_token.rs +++ b/program_methods/guest/src/bin/pinata_token.rs @@ -53,7 +53,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction: solution, }, @@ -100,6 +100,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, vec![ pinata_definition, diff --git a/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/program_methods/guest/src/bin/privacy_preserving_circuit.rs index 48d4b3b7..1d091e1c 100644 --- a/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -114,6 +114,15 @@ impl ExecutionState { "Program output self_program_id does not match chained call program_id" ); + // Verify that the program output's caller_program_id matches the actual caller. + // This prevents a malicious user from privately executing an internal function + // by spoofing caller_program_id (e.g. passing caller_program_id = self_program_id + // to bypass access control checks). + assert_eq!( + program_output.caller_program_id, caller_program_id, + "Program output caller_program_id does not match actual caller" + ); + // Check that the program is well behaved. // See the # Programs section for the definition of the `validate_execution` method. let execution_valid = validate_execution( diff --git a/program_methods/guest/src/bin/token.rs b/program_methods/guest/src/bin/token.rs index c3500da6..68205d77 100644 --- a/program_methods/guest/src/bin/token.rs +++ b/program_methods/guest/src/bin/token.rs @@ -13,7 +13,7 @@ fn main() { let ( ProgramInput { self_program_id, - caller_program_id: _, + caller_program_id, pre_states, instruction, }, @@ -85,6 +85,7 @@ fn main() { ProgramOutput::new( self_program_id, + caller_program_id, instruction_words, pre_states_clone, post_states, diff --git a/test_program_methods/guest/src/bin/malicious_caller_program_id.rs b/test_program_methods/guest/src/bin/malicious_caller_program_id.rs new file mode 100644 index 00000000..2326190e --- /dev/null +++ b/test_program_methods/guest/src/bin/malicious_caller_program_id.rs @@ -0,0 +1,34 @@ +use nssa_core::program::{ + AccountPostState, DEFAULT_PROGRAM_ID, ProgramInput, ProgramOutput, read_nssa_inputs, +}; + +type Instruction = (); + +fn main() { + let ( + ProgramInput { + self_program_id, + caller_program_id: _, // ignore the actual caller + pre_states, + instruction: (), + }, + instruction_words, + ) = read_nssa_inputs::(); + + let post_states = pre_states + .iter() + .map(|a| AccountPostState::new(a.account.clone())) + .collect(); + + // Deliberately output wrong caller_program_id. + // A real caller_program_id is None for a top-level call, so we spoof Some(DEFAULT_PROGRAM_ID) + // to simulate a program claiming it was invoked by another program when it was not. + ProgramOutput::new( + self_program_id, + Some(DEFAULT_PROGRAM_ID), // WRONG: should be None for a top-level call + instruction_words, + pre_states, + post_states, + ) + .write(); +}