Merge pull request #428 from logos-blockchain/moudy/feat-caller-program-id-and-flash-swap

Add caller_program_id to ProgramInput and flash swap demo
This commit is contained in:
Daniil Polyakov 2026-04-07 22:48:56 +03:00 committed by GitHub
commit 8700e404da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
74 changed files with 842 additions and 19 deletions

1
Cargo.lock generated
View File

@ -7846,6 +7846,7 @@ dependencies = [
"clock_core",
"nssa_core",
"risc0-zkvm",
"serde",
]
[[package]]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -20,6 +20,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: greeting,
},
@ -53,6 +54,7 @@ fn main() {
// called to commit the output.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
vec![pre_state],
vec![post_state],

View File

@ -20,6 +20,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: greeting,
},
@ -60,6 +61,7 @@ fn main() {
// called to commit the output.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
vec![pre_state],
vec![post_state],

View File

@ -67,6 +67,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (function_id, data),
},
@ -86,5 +87,12 @@ fn main() {
// WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be
// called to commit the output.
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write();
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states,
post_states,
)
.write();
}

View File

@ -28,6 +28,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (),
},
@ -58,6 +59,7 @@ fn main() {
// called to commit the output.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
vec![pre_state],
vec![post_state],

View File

@ -34,6 +34,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (),
},
@ -71,6 +72,7 @@ fn main() {
// called to commit the output.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
vec![pre_state],
vec![post_state],

View File

@ -17,6 +17,7 @@ pub type ProgramId = [u32; 8];
pub type InstructionData = Vec<u32>;
pub struct ProgramInput<T> {
pub self_program_id: ProgramId,
pub caller_program_id: Option<ProgramId>,
pub pre_states: Vec<AccountWithMetadata>,
pub instruction: T,
}
@ -284,6 +285,9 @@ pub struct InvalidWindow;
pub struct ProgramOutput {
/// The program ID of the program that produced this output.
pub self_program_id: ProgramId,
/// The program ID of the caller that invoked this program via a chained call,
/// or `None` if this is a top-level call.
pub caller_program_id: Option<ProgramId>,
/// The instruction data the program received to produce this output.
pub instruction_data: InstructionData,
/// The account pre states the program received to produce this output.
@ -301,12 +305,14 @@ pub struct ProgramOutput {
impl ProgramOutput {
pub const fn new(
self_program_id: ProgramId,
caller_program_id: Option<ProgramId>,
instruction_data: InstructionData,
pre_states: Vec<AccountWithMetadata>,
post_states: Vec<AccountPostState>,
) -> Self {
Self {
self_program_id,
caller_program_id,
instruction_data,
pre_states,
post_states,
@ -421,12 +427,14 @@ pub fn compute_authorized_pdas(
#[must_use]
pub fn read_nssa_inputs<T: DeserializeOwned>() -> (ProgramInput<T>, InstructionData) {
let self_program_id: ProgramId = env::read();
let caller_program_id: Option<ProgramId> = env::read();
let pre_states: Vec<AccountWithMetadata> = env::read();
let instruction_words: InstructionData = env::read();
let instruction = T::deserialize(&mut Deserializer::new(instruction_words.as_ref())).unwrap();
(
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction,
},
@ -627,7 +635,7 @@ mod tests {
#[test]
fn program_output_try_with_block_validity_window_range() {
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
.try_with_block_validity_window(10_u64..100)
.unwrap();
assert_eq!(output.block_validity_window.start(), Some(10));
@ -636,7 +644,7 @@ mod tests {
#[test]
fn program_output_with_block_validity_window_range_from() {
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
.with_block_validity_window(10_u64..);
assert_eq!(output.block_validity_window.start(), Some(10));
assert_eq!(output.block_validity_window.end(), None);
@ -644,7 +652,7 @@ mod tests {
#[test]
fn program_output_with_block_validity_window_range_to() {
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
.with_block_validity_window(..100_u64);
assert_eq!(output.block_validity_window.start(), None);
assert_eq!(output.block_validity_window.end(), Some(100));
@ -652,7 +660,7 @@ mod tests {
#[test]
fn program_output_try_with_block_validity_window_empty_range_fails() {
let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
.try_with_block_validity_window(5_u64..5);
assert!(result.is_err());
}

View File

@ -87,15 +87,16 @@ pub fn execute_and_prove(
pda_seeds: vec![],
};
let mut chained_calls = VecDeque::from_iter([(initial_call, initial_program)]);
let mut chained_calls = VecDeque::from_iter([(initial_call, initial_program, None)]);
let mut chain_calls_counter = 0;
while let Some((chained_call, program)) = chained_calls.pop_front() {
while let Some((chained_call, program, caller_program_id)) = chained_calls.pop_front() {
if chain_calls_counter >= MAX_NUMBER_CHAINED_CALLS {
return Err(NssaError::MaxChainedCallsDepthExceeded);
}
let inner_receipt = execute_and_prove_program(
program,
caller_program_id,
&chained_call.pre_states,
&chained_call.instruction_data,
)?;
@ -115,7 +116,7 @@ pub fn execute_and_prove(
let next_program = dependencies
.get(&new_call.program_id)
.ok_or(NssaError::InvalidProgramBehavior)?;
chained_calls.push_front((new_call, next_program));
chained_calls.push_front((new_call, next_program, Some(chained_call.program_id)));
}
chain_calls_counter = chain_calls_counter
@ -153,12 +154,19 @@ pub fn execute_and_prove(
fn execute_and_prove_program(
program: &Program,
caller_program_id: Option<ProgramId>,
pre_states: &[AccountWithMetadata],
instruction_data: &InstructionData,
) -> Result<Receipt, NssaError> {
// Write inputs to the program
let mut env_builder = ExecutorEnv::builder();
Program::write_inputs(program.id(), pre_states, instruction_data, &mut env_builder)?;
Program::write_inputs(
program.id(),
caller_program_id,
pre_states,
instruction_data,
&mut env_builder,
)?;
let env = env_builder.build().unwrap();
// Prove the program

View File

@ -53,13 +53,20 @@ impl Program {
pub(crate) fn execute(
&self,
caller_program_id: Option<ProgramId>,
pre_states: &[AccountWithMetadata],
instruction_data: &InstructionData,
) -> Result<ProgramOutput, NssaError> {
// Write inputs to the program
let mut env_builder = ExecutorEnv::builder();
env_builder.session_limit(Some(MAX_NUM_CYCLES_PUBLIC_EXECUTION));
Self::write_inputs(self.id, pre_states, instruction_data, &mut env_builder)?;
Self::write_inputs(
self.id,
caller_program_id,
pre_states,
instruction_data,
&mut env_builder,
)?;
let env = env_builder.build().unwrap();
// Execute the program (without proving)
@ -80,6 +87,7 @@ impl Program {
/// Writes inputs to `env_builder` in the order expected by the programs.
pub(crate) fn write_inputs(
program_id: ProgramId,
caller_program_id: Option<ProgramId>,
pre_states: &[AccountWithMetadata],
instruction_data: &[u32],
env_builder: &mut ExecutorEnvBuilder,
@ -87,6 +95,9 @@ impl Program {
env_builder
.write(&program_id)
.map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?;
env_builder
.write(&caller_program_id)
.map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?;
let pre_states = pre_states.to_vec();
env_builder
.write(&pre_states)
@ -320,6 +331,34 @@ mod tests {
Self::new(VALIDITY_WINDOW_CHAIN_CALLER_ELF.to_vec()).unwrap()
}
#[must_use]
pub fn flash_swap_initiator() -> Self {
use test_program_methods::FLASH_SWAP_INITIATOR_ELF;
Self::new(FLASH_SWAP_INITIATOR_ELF.to_vec())
.expect("flash_swap_initiator must be a valid Risc0 program")
}
#[must_use]
pub fn flash_swap_callback() -> Self {
use test_program_methods::FLASH_SWAP_CALLBACK_ELF;
Self::new(FLASH_SWAP_CALLBACK_ELF.to_vec())
.expect("flash_swap_callback must be a valid Risc0 program")
}
#[must_use]
pub fn malicious_self_program_id() -> Self {
use test_program_methods::MALICIOUS_SELF_PROGRAM_ID_ELF;
Self::new(MALICIOUS_SELF_PROGRAM_ID_ELF.to_vec())
.expect("malicious_self_program_id must be a valid Risc0 program")
}
#[must_use]
pub fn malicious_caller_program_id() -> Self {
use test_program_methods::MALICIOUS_CALLER_PROGRAM_ID_ELF;
Self::new(MALICIOUS_CALLER_PROGRAM_ID_ELF.to_vec())
.expect("malicious_caller_program_id must be a valid Risc0 program")
}
#[must_use]
pub fn time_locked_transfer() -> Self {
use test_program_methods::TIME_LOCKED_TRANSFER_ELF;
@ -358,7 +397,7 @@ mod tests {
..Account::default()
};
let program_output = program
.execute(&[sender, recipient], &instruction_data)
.execute(None, &[sender, recipient], &instruction_data)
.unwrap();
let [sender_post, recipient_post] = program_output.post_states.try_into().unwrap();

View File

@ -400,6 +400,10 @@ pub mod tests {
self.insert_program(Program::claimer());
self.insert_program(Program::changer_claimer());
self.insert_program(Program::validity_window());
self.insert_program(Program::flash_swap_initiator());
self.insert_program(Program::flash_swap_callback());
self.insert_program(Program::malicious_self_program_id());
self.insert_program(Program::malicious_caller_program_id());
self.insert_program(Program::time_locked_transfer());
self.insert_program(Program::pinata_cooldown());
self
@ -478,6 +482,28 @@ pub mod tests {
}
}
// ── Flash Swap types (mirrors of guest types for host-side serialisation) ──
#[derive(serde::Serialize, serde::Deserialize)]
struct CallbackInstruction {
return_funds: bool,
token_program_id: ProgramId,
amount: u128,
}
#[derive(serde::Serialize, serde::Deserialize)]
enum FlashSwapInstruction {
Initiate {
token_program_id: ProgramId,
callback_program_id: ProgramId,
amount_out: u128,
callback_instruction_data: Vec<u32>,
},
InvariantCheck {
min_vault_balance: u128,
},
}
fn transfer_transaction(
from: AccountId,
from_key: &PrivateKey,
@ -497,6 +523,23 @@ pub mod tests {
PublicTransaction::new(message, witness_set)
}
fn build_flash_swap_tx(
initiator: &Program,
vault_id: AccountId,
receiver_id: AccountId,
instruction: FlashSwapInstruction,
) -> PublicTransaction {
let message = public_transaction::Message::try_new(
initiator.id(),
vec![vault_id, receiver_id],
vec![], // no signers — vault is PDA-authorised
instruction,
)
.unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
PublicTransaction::new(message, witness_set)
}
#[test]
fn new_with_genesis() {
let key1 = PrivateKey::try_new([1; 32]).unwrap();
@ -3877,4 +3920,242 @@ pub mod tests {
let state_from_bytes: V03State = borsh::from_slice(&bytes).unwrap();
assert_eq!(state, state_from_bytes);
}
#[test]
fn flash_swap_successful() {
let initiator = Program::flash_swap_initiator();
let callback = Program::flash_swap_callback();
let token = Program::authenticated_transfer_program();
let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32])));
let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32])));
let initial_balance: u128 = 1000;
let amount_out: u128 = 100;
let vault_account = Account {
program_owner: token.id(),
balance: initial_balance,
..Account::default()
};
let receiver_account = Account {
program_owner: token.id(),
balance: 0,
..Account::default()
};
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
state.force_insert_account(vault_id, vault_account);
state.force_insert_account(receiver_id, receiver_account);
// Callback instruction: return funds
let cb_instruction = CallbackInstruction {
return_funds: true,
token_program_id: token.id(),
amount: amount_out,
};
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
let instruction = FlashSwapInstruction::Initiate {
token_program_id: token.id(),
callback_program_id: callback.id(),
amount_out,
callback_instruction_data: cb_data,
};
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);
let result = state.transition_from_public_transaction(&tx, 1, 0);
assert!(result.is_ok(), "flash swap should succeed: {result:?}");
// Vault balance restored, receiver back to 0
assert_eq!(state.get_account_by_id(vault_id).balance, initial_balance);
assert_eq!(state.get_account_by_id(receiver_id).balance, 0);
}
#[test]
fn flash_swap_callback_keeps_funds_rollback() {
let initiator = Program::flash_swap_initiator();
let callback = Program::flash_swap_callback();
let token = Program::authenticated_transfer_program();
let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32])));
let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32])));
let initial_balance: u128 = 1000;
let amount_out: u128 = 100;
let vault_account = Account {
program_owner: token.id(),
balance: initial_balance,
..Account::default()
};
let receiver_account = Account {
program_owner: token.id(),
balance: 0,
..Account::default()
};
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
state.force_insert_account(vault_id, vault_account);
state.force_insert_account(receiver_id, receiver_account);
// Callback instruction: do NOT return funds
let cb_instruction = CallbackInstruction {
return_funds: false,
token_program_id: token.id(),
amount: amount_out,
};
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
let instruction = FlashSwapInstruction::Initiate {
token_program_id: token.id(),
callback_program_id: callback.id(),
amount_out,
callback_instruction_data: cb_data,
};
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);
let result = state.transition_from_public_transaction(&tx, 1, 0);
// Invariant check fails → entire tx rolls back
assert!(
result.is_err(),
"flash swap should fail when callback keeps funds"
);
// State unchanged (rollback)
assert_eq!(state.get_account_by_id(vault_id).balance, initial_balance);
assert_eq!(state.get_account_by_id(receiver_id).balance, 0);
}
#[test]
fn flash_swap_self_call_targets_correct_program() {
// Zero-amount flash swap: the invariant self-call still runs and succeeds
// because vault balance doesn't decrease.
let initiator = Program::flash_swap_initiator();
let callback = Program::flash_swap_callback();
let token = Program::authenticated_transfer_program();
let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32])));
let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32])));
let initial_balance: u128 = 1000;
let vault_account = Account {
program_owner: token.id(),
balance: initial_balance,
..Account::default()
};
let receiver_account = Account {
program_owner: token.id(),
balance: 0,
..Account::default()
};
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
state.force_insert_account(vault_id, vault_account);
state.force_insert_account(receiver_id, receiver_account);
let cb_instruction = CallbackInstruction {
return_funds: true,
token_program_id: token.id(),
amount: 0,
};
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
let instruction = FlashSwapInstruction::Initiate {
token_program_id: token.id(),
callback_program_id: callback.id(),
amount_out: 0,
callback_instruction_data: cb_data,
};
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);
let result = state.transition_from_public_transaction(&tx, 1, 0);
assert!(
result.is_ok(),
"zero-amount flash swap should succeed: {result:?}"
);
}
#[test]
fn flash_swap_standalone_invariant_check_rejected() {
// Calling InvariantCheck directly (not as a chained self-call) should fail
// because caller_program_id will be None.
let initiator = Program::flash_swap_initiator();
let token = Program::authenticated_transfer_program();
let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32])));
let vault_account = Account {
program_owner: token.id(),
balance: 1000,
..Account::default()
};
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
state.force_insert_account(vault_id, vault_account);
let instruction = FlashSwapInstruction::InvariantCheck {
min_vault_balance: 1000,
};
let message = public_transaction::Message::try_new(
initiator.id(),
vec![vault_id],
vec![],
instruction,
)
.unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
let tx = PublicTransaction::new(message, witness_set);
let result = state.transition_from_public_transaction(&tx, 1, 0);
assert!(
result.is_err(),
"standalone InvariantCheck should be rejected (caller_program_id is None)"
);
}
#[test]
fn malicious_self_program_id_rejected_in_public_execution() {
let program = Program::malicious_self_program_id();
let acc_id = AccountId::new([99; 32]);
let account = Account::default();
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
state.force_insert_account(acc_id, account);
let message =
public_transaction::Message::try_new(program.id(), vec![acc_id], vec![], ()).unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
let tx = PublicTransaction::new(message, witness_set);
let result = state.transition_from_public_transaction(&tx, 1, 0);
assert!(
result.is_err(),
"program with wrong self_program_id in output should be rejected"
);
}
#[test]
fn malicious_caller_program_id_rejected_in_public_execution() {
let program = Program::malicious_caller_program_id();
let acc_id = AccountId::new([99; 32]);
let account = Account::default();
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
state.force_insert_account(acc_id, account);
let message =
public_transaction::Message::try_new(program.id(), vec![acc_id], vec![], ()).unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
let tx = PublicTransaction::new(message, witness_set);
let result = state.transition_from_public_transaction(&tx, 1, 0);
assert!(
result.is_err(),
"program with spoofed caller_program_id in output should be rejected"
);
}
}

View File

@ -118,8 +118,11 @@ impl ValidatedStateDiff {
"Program {:?} pre_states: {:?}, instruction_data: {:?}",
chained_call.program_id, chained_call.pre_states, chained_call.instruction_data
);
let mut program_output =
program.execute(&chained_call.pre_states, &chained_call.instruction_data)?;
let mut program_output = program.execute(
caller_program_id,
&chained_call.pre_states,
&chained_call.instruction_data,
)?;
debug!(
"Program {:?} output: {:?}",
chained_call.program_id, program_output
@ -159,6 +162,12 @@ impl ValidatedStateDiff {
NssaError::InvalidProgramBehavior
);
// Verify that the program output's caller_program_id matches the actual caller.
ensure!(
program_output.caller_program_id == caller_program_id,
NssaError::InvalidProgramBehavior
);
// Verify execution corresponds to a well-behaved program.
// See the # Programs section for the definition of the `validate_execution` method.
ensure!(

View File

@ -15,6 +15,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction,
},
@ -155,6 +156,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states_clone,
post_states,

View File

@ -5,6 +5,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction,
},
@ -59,6 +60,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states_clone,
post_states,

View File

@ -68,6 +68,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: balance_to_move,
},
@ -85,5 +86,12 @@ fn main() {
_ => panic!("invalid params"),
};
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write();
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states,
post_states,
)
.write();
}

View File

@ -40,6 +40,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: timestamp,
},
@ -84,6 +85,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre_01, pre_10, pre_50],
vec![post_01, post_10, post_50],

View File

@ -47,6 +47,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: solution,
},
@ -81,6 +82,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pinata, winner],
vec![

View File

@ -53,6 +53,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: solution,
},
@ -99,6 +100,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![
pinata_definition,

View File

@ -114,6 +114,15 @@ impl ExecutionState {
"Program output self_program_id does not match chained call program_id"
);
// Verify that the program output's caller_program_id matches the actual caller.
// This prevents a malicious user from privately executing an internal function
// by spoofing caller_program_id (e.g. passing caller_program_id = self_program_id
// to bypass access control checks).
assert_eq!(
program_output.caller_program_id, caller_program_id,
"Program output caller_program_id does not match actual caller"
);
// Check that the program is well behaved.
// See the # Programs section for the definition of the `validate_execution` method.
let execution_valid = validate_execution(

View File

@ -13,6 +13,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction,
},
@ -84,6 +85,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states_clone,
post_states,

View File

@ -12,3 +12,4 @@ nssa_core.workspace = true
clock_core.workspace = true
risc0-zkvm.workspace = true
serde = { workspace = true, default-features = false }

View File

@ -6,6 +6,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: balance_to_burn,
},
@ -22,6 +23,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![AccountPostState::new(account_post)],

View File

@ -14,6 +14,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (balance, auth_transfer_id, num_chain_calls, pda_seed),
},
@ -57,6 +58,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![sender_pre.clone(), recipient_pre.clone()],
vec![

View File

@ -7,6 +7,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (data_opt, should_claim),
},
@ -36,6 +37,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![post_state],

View File

@ -6,6 +6,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (),
},
@ -20,6 +21,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![account_post],

View File

@ -15,6 +15,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (clock_program_id, timestamp),
},
@ -33,7 +34,13 @@ fn main() {
pda_seeds: vec![],
};
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states)
.with_chained_calls(vec![chained_call])
.write();
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states,
post_states,
)
.with_chained_calls(vec![chained_call])
.write();
}

View File

@ -7,6 +7,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: data,
},
@ -25,6 +26,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![AccountPostState::new_claimed(

View File

@ -9,6 +9,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
..
},
@ -23,6 +24,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![

View File

@ -0,0 +1,94 @@
//! Flash swap callback, the user logic step in the "prep → callback → assert" pattern.
//!
//! # Role
//!
//! This program is called as chained call 2 in the flash swap sequence:
//! 1. Token transfer out (vault → receiver)
//! 2. **This callback** (user logic)
//! 3. Invariant check (assert vault balance restored)
//!
//! In a real flash swap, this would contain the user's arbitrage or other logic.
//! In this test program, it is controlled by `return_funds`:
//!
//! - `return_funds = true`: emits a token transfer (receiver → vault) to return the funds. The
//! invariant check will pass and the transaction will succeed.
//!
//! - `return_funds = false`: emits no transfers. Funds stay with the receiver. The invariant check
//! will fail (vault balance < initial), causing full atomic rollback. This simulates a malicious
//! or buggy callback that does not repay the flash loan.
//!
//! # Note on `caller_program_id`
//!
//! This program does not enforce any access control on `caller_program_id`.
//! It is designed to be called by the flash swap initiator but could in principle be
//! called by any program. In production, a callback would typically verify the caller
//! if it needs to trust the context it is called from.
use nssa_core::program::{
AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput,
read_nssa_inputs,
};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct CallbackInstruction {
/// If true, return the borrowed funds to the vault (happy path).
/// If false, keep the funds (simulates a malicious callback, triggers rollback).
pub return_funds: bool,
pub token_program_id: ProgramId,
pub amount: u128,
}
fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id, // not enforced in this callback
pre_states,
instruction,
},
instruction_words,
) = read_nssa_inputs::<CallbackInstruction>();
// pre_states[0] = vault (after transfer out), pre_states[1] = receiver (after transfer out)
let Ok([vault_pre, receiver_pre]) = <[_; 2]>::try_from(pre_states) else {
panic!("Callback requires exactly 2 accounts: vault, receiver");
};
let mut chained_calls = Vec::new();
if instruction.return_funds {
// Happy path: return the borrowed funds via a token transfer (receiver → vault).
// The receiver is a PDA of this callback program (seed = [1_u8; 32]).
// Mark the receiver as authorized since it will be PDA-authorized in this chained call.
let mut receiver_authorized = receiver_pre.clone();
receiver_authorized.is_authorized = true;
let transfer_instruction = risc0_zkvm::serde::to_vec(&instruction.amount)
.expect("transfer instruction serialization");
chained_calls.push(ChainedCall {
program_id: instruction.token_program_id,
pre_states: vec![receiver_authorized, vault_pre.clone()],
instruction_data: transfer_instruction,
pda_seeds: vec![PdaSeed::new([1_u8; 32])],
});
}
// Malicious path (return_funds = false): emit no chained calls.
// The vault balance will not be restored, so the invariant check in the initiator
// will panic, rolling back the entire transaction including the initial transfer out.
// The callback itself makes no direct state changes, accounts pass through unchanged.
// All mutations go through the token program via chained calls.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![vault_pre.clone(), receiver_pre.clone()],
vec![
AccountPostState::new(vault_pre.account),
AccountPostState::new(receiver_pre.account),
],
)
.with_chained_calls(chained_calls)
.write();
}

View File

@ -0,0 +1,216 @@
//! Flash swap initiator, demonstrates the "prep → callback → assert" pattern using
//! generalized multi tail-calls with `self_program_id` and `caller_program_id`.
//!
//! # Pattern
//!
//! A flash swap lets a program optimistically transfer tokens out, run arbitrary user
//! logic (the callback), then assert that invariants hold after the callback. The entire
//! sequence is a single atomic transaction: if any step fails, all state changes roll back.
//!
//! # How it works
//!
//! This program handles two instruction variants:
//!
//! - `Initiate` (external): the top-level entrypoint. Emits 3 chained calls:
//! 1. Token transfer out (vault → receiver)
//! 2. User callback (arbitrary logic, e.g. arbitrage)
//! 3. Self-call to `InvariantCheck` (using `self_program_id` to reference itself)
//!
//! - `InvariantCheck` (internal): enforces that the vault balance was restored after the callback.
//! Uses `caller_program_id == Some(self_program_id)` to prevent standalone calls (this is the
//! visibility enforcement mechanism).
//!
//! # What this demonstrates
//!
//! - `self_program_id`: enables a program to chain back to itself (step 3 above)
//! - `caller_program_id`: enables a program to restrict which callers can invoke an instruction
//! - Computed intermediate states: the initiator computes expected intermediate account states from
//! the `pre_states` and amount, keeping the instruction minimal.
//! - Atomic rollback: if the callback doesn't return funds, the invariant check fails, and all
//! state changes from steps 1 and 2 are rolled back automatically.
//!
//! # Tests
//!
//! See `nssa/src/state.rs` for integration tests:
//! - `flash_swap_successful`: full round-trip, funds returned, state unchanged
//! - `flash_swap_callback_keeps_funds_rollback`: callback keeps funds, full rollback
//! - `flash_swap_self_call_targets_correct_program`: zero-amount self-call isolation test
//! - `flash_swap_standalone_invariant_check_rejected`: `caller_program_id` access control
use nssa_core::program::{
AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput,
read_nssa_inputs,
};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub enum FlashSwapInstruction {
/// External entrypoint: initiate a flash swap.
///
/// Emits 3 chained calls:
/// 1. Token transfer (vault → receiver, `amount_out`)
/// 2. Callback (user logic, e.g. arbitrage)
/// 3. Self-call `InvariantCheck` (verify vault balance did not decrease)
///
/// Intermediate account states are computed inside the program from `pre_states` and
/// `amount_out`.
Initiate {
token_program_id: ProgramId,
callback_program_id: ProgramId,
amount_out: u128,
callback_instruction_data: Vec<u32>,
},
/// Internal: verify the vault invariant holds after callback execution.
///
/// Access control: only callable as a chained call from this program itself.
/// This is enforced by checking `caller_program_id == Some(self_program_id)`.
/// Any attempt to call this instruction as a standalone top-level transaction
/// will be rejected because `caller_program_id` will be `None`.
InvariantCheck { min_vault_balance: u128 },
}
fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction,
},
instruction_words,
) = read_nssa_inputs::<FlashSwapInstruction>();
match instruction {
FlashSwapInstruction::Initiate {
token_program_id,
callback_program_id,
amount_out,
callback_instruction_data,
} => {
let Ok([vault_pre, receiver_pre]) = <[_; 2]>::try_from(pre_states) else {
panic!("Initiate requires exactly 2 accounts: vault, receiver");
};
// Capture initial vault balance, the invariant check will verify it is restored.
let min_vault_balance = vault_pre.account.balance;
// Compute intermediate account states from pre_states and amount_out.
let mut vault_after_transfer = vault_pre.clone();
vault_after_transfer.account.balance = vault_pre
.account
.balance
.checked_sub(amount_out)
.expect("vault has insufficient balance for flash swap");
let mut receiver_after_transfer = receiver_pre.clone();
receiver_after_transfer.account.balance = receiver_pre
.account
.balance
.checked_add(amount_out)
.expect("receiver balance overflow");
let mut vault_after_callback = vault_after_transfer.clone();
vault_after_callback.account.balance = vault_after_transfer
.account
.balance
.checked_add(amount_out)
.expect("vault balance overflow after callback");
// Chained call 1: Token transfer (vault → receiver).
// The vault is a PDA of this initiator program (seed = [0_u8; 32]), so we provide
// the PDA seed to authorize the token program to debit the vault on our behalf.
// Mark the vault as authorized since it will be PDA-authorized in this chained call.
let mut vault_authorized = vault_pre.clone();
vault_authorized.is_authorized = true;
let transfer_instruction =
risc0_zkvm::serde::to_vec(&amount_out).expect("transfer instruction serialization");
let call_1 = ChainedCall {
program_id: token_program_id,
pre_states: vec![vault_authorized, receiver_pre.clone()],
instruction_data: transfer_instruction,
pda_seeds: vec![PdaSeed::new([0_u8; 32])],
};
// Chained call 2: User callback.
// Receives the post-transfer states as its pre_states. The callback may run
// arbitrary logic (arbitrage, etc.) and is expected to return funds to the vault.
let call_2 = ChainedCall {
program_id: callback_program_id,
pre_states: vec![vault_after_transfer, receiver_after_transfer],
instruction_data: callback_instruction_data,
pda_seeds: vec![],
};
// Chained call 3: Self-call to enforce the invariant.
// Uses `self_program_id` to reference this program, the key feature that enables
// the "prep → callback → assert" pattern without a separate checker program.
// If the callback did not return funds, vault_after_callback.balance <
// min_vault_balance and this call will panic, rolling back the entire
// transaction.
let invariant_instruction =
risc0_zkvm::serde::to_vec(&FlashSwapInstruction::InvariantCheck {
min_vault_balance,
})
.expect("invariant instruction serialization");
let call_3 = ChainedCall {
program_id: self_program_id, // self-referential chained call
pre_states: vec![vault_after_callback],
instruction_data: invariant_instruction,
pda_seeds: vec![],
};
// The initiator itself makes no direct state changes.
// All mutations happen inside the chained calls (token transfers).
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![vault_pre.clone(), receiver_pre.clone()],
vec![
AccountPostState::new(vault_pre.account),
AccountPostState::new(receiver_pre.account),
],
)
.with_chained_calls(vec![call_1, call_2, call_3])
.write();
}
FlashSwapInstruction::InvariantCheck { min_vault_balance } => {
// Visibility enforcement: `InvariantCheck` is an internal instruction.
// It must only be called as a chained call from this program itself (via `Initiate`).
// When called as a top-level transaction, `caller_program_id` is `None` → panics.
// When called as a chained call from `Initiate`, `caller_program_id` is
// `Some(self_program_id)` → passes.
assert_eq!(
caller_program_id,
Some(self_program_id),
"InvariantCheck is an internal instruction: must be called by flash_swap_initiator \
via a chained call",
);
let Ok([vault]) = <[_; 1]>::try_from(pre_states) else {
panic!("InvariantCheck requires exactly 1 account: vault");
};
// The core invariant: vault balance must not have decreased.
// If the callback returned funds, this passes. If not, this panics and
// the entire transaction (including the prior token transfer) rolls back.
assert!(
vault.account.balance >= min_vault_balance,
"Flash swap invariant violated: vault balance {} < minimum {}",
vault.account.balance,
min_vault_balance
);
// Pass-through: no state changes in the invariant check step.
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![vault.clone()],
vec![AccountPostState::new(vault.account)],
)
.write();
}
}
}

View File

@ -15,6 +15,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (balance, transfer_program_id),
},
@ -42,6 +43,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![sender.clone(), receiver.clone()],
vec![

View File

@ -0,0 +1,34 @@
use nssa_core::program::{
AccountPostState, DEFAULT_PROGRAM_ID, ProgramInput, ProgramOutput, read_nssa_inputs,
};
type Instruction = ();
fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id: _, // ignore the actual caller
pre_states,
instruction: (),
},
instruction_words,
) = read_nssa_inputs::<Instruction>();
let post_states = pre_states
.iter()
.map(|a| AccountPostState::new(a.account.clone()))
.collect();
// Deliberately output wrong caller_program_id.
// A real caller_program_id is None for a top-level call, so we spoof Some(DEFAULT_PROGRAM_ID)
// to simulate a program claiming it was invoked by another program when it was not.
ProgramOutput::new(
self_program_id,
Some(DEFAULT_PROGRAM_ID), // WRONG: should be None for a top-level call
instruction_words,
pre_states,
post_states,
)
.write();
}

View File

@ -0,0 +1,32 @@
use nssa_core::program::{
AccountPostState, DEFAULT_PROGRAM_ID, ProgramInput, ProgramOutput, read_nssa_inputs,
};
type Instruction = ();
fn main() {
let (
ProgramInput {
self_program_id: _, // ignore the correct ID
caller_program_id,
pre_states,
instruction: (),
},
instruction_words,
) = read_nssa_inputs::<Instruction>();
let post_states = pre_states
.iter()
.map(|a| AccountPostState::new(a.account.clone()))
.collect();
// Deliberately output wrong self_program_id
ProgramOutput::new(
DEFAULT_PROGRAM_ID, // WRONG: should be self_program_id
caller_program_id,
instruction_words,
pre_states,
post_states,
)
.write();
}

View File

@ -6,6 +6,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
..
},
@ -25,6 +26,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![AccountPostState::new(account_post)],

View File

@ -6,6 +6,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
..
},
@ -20,6 +21,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre1, pre2],
vec![AccountPostState::new(account_pre1)],

View File

@ -65,6 +65,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: balance_to_move,
},
@ -81,5 +82,12 @@ fn main() {
}
_ => panic!("invalid params"),
};
ProgramOutput::new(self_program_id, instruction_data, pre_states, post_states).write();
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_data,
pre_states,
post_states,
)
.write();
}

View File

@ -6,6 +6,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
..
},
@ -22,6 +23,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![AccountPostState::new(account_post)],

View File

@ -6,6 +6,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
..
},
@ -16,5 +17,12 @@ fn main() {
.iter()
.map(|account| AccountPostState::new(account.account.clone()))
.collect();
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write();
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
pre_states,
post_states,
)
.write();
}

View File

@ -49,6 +49,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (),
},
@ -102,6 +103,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pinata, winner, clock_pre],
vec![

View File

@ -6,6 +6,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
..
},
@ -22,6 +23,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![AccountPostState::new(account_post)],

View File

@ -6,6 +6,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: balance,
},
@ -29,6 +30,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![sender_pre, receiver_pre],
vec![

View File

@ -19,6 +19,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (amount, deadline),
},
@ -58,6 +59,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![sender_pre, receiver_pre, clock_pre],
vec![

View File

@ -9,6 +9,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (block_validity_window, timestamp_validity_window),
},
@ -23,6 +24,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![AccountPostState::new(post)],

View File

@ -17,6 +17,7 @@ fn main() {
let (
ProgramInput {
self_program_id,
caller_program_id,
pre_states,
instruction: (block_validity_window, chained_program_id, chained_block_validity_window),
},
@ -40,6 +41,7 @@ fn main() {
ProgramOutput::new(
self_program_id,
caller_program_id,
instruction_words,
vec![pre],
vec![AccountPostState::new(post)],