mirror of
https://github.com/logos-blockchain/lssa.git
synced 2026-04-10 21:23:30 +00:00
Merge pull request #428 from logos-blockchain/moudy/feat-caller-program-id-and-flash-swap
Add caller_program_id to ProgramInput and flash swap demo
This commit is contained in:
commit
8700e404da
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -7846,6 +7846,7 @@ dependencies = [
|
||||
"clock_core",
|
||||
"nssa_core",
|
||||
"risc0-zkvm",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
artifacts/test_program_methods/flash_swap_callback.bin
Normal file
BIN
artifacts/test_program_methods/flash_swap_callback.bin
Normal file
Binary file not shown.
BIN
artifacts/test_program_methods/flash_swap_initiator.bin
Normal file
BIN
artifacts/test_program_methods/flash_swap_initiator.bin
Normal file
Binary file not shown.
Binary file not shown.
BIN
artifacts/test_program_methods/malicious_caller_program_id.bin
Normal file
BIN
artifacts/test_program_methods/malicious_caller_program_id.bin
Normal file
Binary file not shown.
BIN
artifacts/test_program_methods/malicious_self_program_id.bin
Normal file
BIN
artifacts/test_program_methods/malicious_self_program_id.bin
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -20,6 +20,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: greeting,
|
||||
},
|
||||
@ -53,6 +54,7 @@ fn main() {
|
||||
// called to commit the output.
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_data,
|
||||
vec![pre_state],
|
||||
vec![post_state],
|
||||
|
||||
@ -20,6 +20,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: greeting,
|
||||
},
|
||||
@ -60,6 +61,7 @@ fn main() {
|
||||
// called to commit the output.
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_data,
|
||||
vec![pre_state],
|
||||
vec![post_state],
|
||||
|
||||
@ -67,6 +67,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (function_id, data),
|
||||
},
|
||||
@ -86,5 +87,12 @@ fn main() {
|
||||
|
||||
// WARNING: constructing a `ProgramOutput` has no effect on its own. `.write()` must be
|
||||
// called to commit the output.
|
||||
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write();
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
pre_states,
|
||||
post_states,
|
||||
)
|
||||
.write();
|
||||
}
|
||||
|
||||
@ -28,6 +28,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (),
|
||||
},
|
||||
@ -58,6 +59,7 @@ fn main() {
|
||||
// called to commit the output.
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_data,
|
||||
vec![pre_state],
|
||||
vec![post_state],
|
||||
|
||||
@ -34,6 +34,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (),
|
||||
},
|
||||
@ -71,6 +72,7 @@ fn main() {
|
||||
// called to commit the output.
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_data,
|
||||
vec![pre_state],
|
||||
vec![post_state],
|
||||
|
||||
@ -17,6 +17,7 @@ pub type ProgramId = [u32; 8];
|
||||
pub type InstructionData = Vec<u32>;
|
||||
pub struct ProgramInput<T> {
|
||||
pub self_program_id: ProgramId,
|
||||
pub caller_program_id: Option<ProgramId>,
|
||||
pub pre_states: Vec<AccountWithMetadata>,
|
||||
pub instruction: T,
|
||||
}
|
||||
@ -284,6 +285,9 @@ pub struct InvalidWindow;
|
||||
pub struct ProgramOutput {
|
||||
/// The program ID of the program that produced this output.
|
||||
pub self_program_id: ProgramId,
|
||||
/// The program ID of the caller that invoked this program via a chained call,
|
||||
/// or `None` if this is a top-level call.
|
||||
pub caller_program_id: Option<ProgramId>,
|
||||
/// The instruction data the program received to produce this output.
|
||||
pub instruction_data: InstructionData,
|
||||
/// The account pre states the program received to produce this output.
|
||||
@ -301,12 +305,14 @@ pub struct ProgramOutput {
|
||||
impl ProgramOutput {
|
||||
pub const fn new(
|
||||
self_program_id: ProgramId,
|
||||
caller_program_id: Option<ProgramId>,
|
||||
instruction_data: InstructionData,
|
||||
pre_states: Vec<AccountWithMetadata>,
|
||||
post_states: Vec<AccountPostState>,
|
||||
) -> Self {
|
||||
Self {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_data,
|
||||
pre_states,
|
||||
post_states,
|
||||
@ -421,12 +427,14 @@ pub fn compute_authorized_pdas(
|
||||
#[must_use]
|
||||
pub fn read_nssa_inputs<T: DeserializeOwned>() -> (ProgramInput<T>, InstructionData) {
|
||||
let self_program_id: ProgramId = env::read();
|
||||
let caller_program_id: Option<ProgramId> = env::read();
|
||||
let pre_states: Vec<AccountWithMetadata> = env::read();
|
||||
let instruction_words: InstructionData = env::read();
|
||||
let instruction = T::deserialize(&mut Deserializer::new(instruction_words.as_ref())).unwrap();
|
||||
(
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction,
|
||||
},
|
||||
@ -627,7 +635,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn program_output_try_with_block_validity_window_range() {
|
||||
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
|
||||
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
|
||||
.try_with_block_validity_window(10_u64..100)
|
||||
.unwrap();
|
||||
assert_eq!(output.block_validity_window.start(), Some(10));
|
||||
@ -636,7 +644,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn program_output_with_block_validity_window_range_from() {
|
||||
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
|
||||
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
|
||||
.with_block_validity_window(10_u64..);
|
||||
assert_eq!(output.block_validity_window.start(), Some(10));
|
||||
assert_eq!(output.block_validity_window.end(), None);
|
||||
@ -644,7 +652,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn program_output_with_block_validity_window_range_to() {
|
||||
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
|
||||
let output = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
|
||||
.with_block_validity_window(..100_u64);
|
||||
assert_eq!(output.block_validity_window.start(), None);
|
||||
assert_eq!(output.block_validity_window.end(), Some(100));
|
||||
@ -652,7 +660,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn program_output_try_with_block_validity_window_empty_range_fails() {
|
||||
let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, vec![], vec![], vec![])
|
||||
let result = ProgramOutput::new(DEFAULT_PROGRAM_ID, None, vec![], vec![], vec![])
|
||||
.try_with_block_validity_window(5_u64..5);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@ -87,15 +87,16 @@ pub fn execute_and_prove(
|
||||
pda_seeds: vec![],
|
||||
};
|
||||
|
||||
let mut chained_calls = VecDeque::from_iter([(initial_call, initial_program)]);
|
||||
let mut chained_calls = VecDeque::from_iter([(initial_call, initial_program, None)]);
|
||||
let mut chain_calls_counter = 0;
|
||||
while let Some((chained_call, program)) = chained_calls.pop_front() {
|
||||
while let Some((chained_call, program, caller_program_id)) = chained_calls.pop_front() {
|
||||
if chain_calls_counter >= MAX_NUMBER_CHAINED_CALLS {
|
||||
return Err(NssaError::MaxChainedCallsDepthExceeded);
|
||||
}
|
||||
|
||||
let inner_receipt = execute_and_prove_program(
|
||||
program,
|
||||
caller_program_id,
|
||||
&chained_call.pre_states,
|
||||
&chained_call.instruction_data,
|
||||
)?;
|
||||
@ -115,7 +116,7 @@ pub fn execute_and_prove(
|
||||
let next_program = dependencies
|
||||
.get(&new_call.program_id)
|
||||
.ok_or(NssaError::InvalidProgramBehavior)?;
|
||||
chained_calls.push_front((new_call, next_program));
|
||||
chained_calls.push_front((new_call, next_program, Some(chained_call.program_id)));
|
||||
}
|
||||
|
||||
chain_calls_counter = chain_calls_counter
|
||||
@ -153,12 +154,19 @@ pub fn execute_and_prove(
|
||||
|
||||
fn execute_and_prove_program(
|
||||
program: &Program,
|
||||
caller_program_id: Option<ProgramId>,
|
||||
pre_states: &[AccountWithMetadata],
|
||||
instruction_data: &InstructionData,
|
||||
) -> Result<Receipt, NssaError> {
|
||||
// Write inputs to the program
|
||||
let mut env_builder = ExecutorEnv::builder();
|
||||
Program::write_inputs(program.id(), pre_states, instruction_data, &mut env_builder)?;
|
||||
Program::write_inputs(
|
||||
program.id(),
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction_data,
|
||||
&mut env_builder,
|
||||
)?;
|
||||
let env = env_builder.build().unwrap();
|
||||
|
||||
// Prove the program
|
||||
|
||||
@ -53,13 +53,20 @@ impl Program {
|
||||
|
||||
pub(crate) fn execute(
|
||||
&self,
|
||||
caller_program_id: Option<ProgramId>,
|
||||
pre_states: &[AccountWithMetadata],
|
||||
instruction_data: &InstructionData,
|
||||
) -> Result<ProgramOutput, NssaError> {
|
||||
// Write inputs to the program
|
||||
let mut env_builder = ExecutorEnv::builder();
|
||||
env_builder.session_limit(Some(MAX_NUM_CYCLES_PUBLIC_EXECUTION));
|
||||
Self::write_inputs(self.id, pre_states, instruction_data, &mut env_builder)?;
|
||||
Self::write_inputs(
|
||||
self.id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction_data,
|
||||
&mut env_builder,
|
||||
)?;
|
||||
let env = env_builder.build().unwrap();
|
||||
|
||||
// Execute the program (without proving)
|
||||
@ -80,6 +87,7 @@ impl Program {
|
||||
/// Writes inputs to `env_builder` in the order expected by the programs.
|
||||
pub(crate) fn write_inputs(
|
||||
program_id: ProgramId,
|
||||
caller_program_id: Option<ProgramId>,
|
||||
pre_states: &[AccountWithMetadata],
|
||||
instruction_data: &[u32],
|
||||
env_builder: &mut ExecutorEnvBuilder,
|
||||
@ -87,6 +95,9 @@ impl Program {
|
||||
env_builder
|
||||
.write(&program_id)
|
||||
.map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?;
|
||||
env_builder
|
||||
.write(&caller_program_id)
|
||||
.map_err(|e| NssaError::ProgramWriteInputFailed(e.to_string()))?;
|
||||
let pre_states = pre_states.to_vec();
|
||||
env_builder
|
||||
.write(&pre_states)
|
||||
@ -320,6 +331,34 @@ mod tests {
|
||||
Self::new(VALIDITY_WINDOW_CHAIN_CALLER_ELF.to_vec()).unwrap()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn flash_swap_initiator() -> Self {
|
||||
use test_program_methods::FLASH_SWAP_INITIATOR_ELF;
|
||||
Self::new(FLASH_SWAP_INITIATOR_ELF.to_vec())
|
||||
.expect("flash_swap_initiator must be a valid Risc0 program")
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn flash_swap_callback() -> Self {
|
||||
use test_program_methods::FLASH_SWAP_CALLBACK_ELF;
|
||||
Self::new(FLASH_SWAP_CALLBACK_ELF.to_vec())
|
||||
.expect("flash_swap_callback must be a valid Risc0 program")
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn malicious_self_program_id() -> Self {
|
||||
use test_program_methods::MALICIOUS_SELF_PROGRAM_ID_ELF;
|
||||
Self::new(MALICIOUS_SELF_PROGRAM_ID_ELF.to_vec())
|
||||
.expect("malicious_self_program_id must be a valid Risc0 program")
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn malicious_caller_program_id() -> Self {
|
||||
use test_program_methods::MALICIOUS_CALLER_PROGRAM_ID_ELF;
|
||||
Self::new(MALICIOUS_CALLER_PROGRAM_ID_ELF.to_vec())
|
||||
.expect("malicious_caller_program_id must be a valid Risc0 program")
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn time_locked_transfer() -> Self {
|
||||
use test_program_methods::TIME_LOCKED_TRANSFER_ELF;
|
||||
@ -358,7 +397,7 @@ mod tests {
|
||||
..Account::default()
|
||||
};
|
||||
let program_output = program
|
||||
.execute(&[sender, recipient], &instruction_data)
|
||||
.execute(None, &[sender, recipient], &instruction_data)
|
||||
.unwrap();
|
||||
|
||||
let [sender_post, recipient_post] = program_output.post_states.try_into().unwrap();
|
||||
|
||||
@ -400,6 +400,10 @@ pub mod tests {
|
||||
self.insert_program(Program::claimer());
|
||||
self.insert_program(Program::changer_claimer());
|
||||
self.insert_program(Program::validity_window());
|
||||
self.insert_program(Program::flash_swap_initiator());
|
||||
self.insert_program(Program::flash_swap_callback());
|
||||
self.insert_program(Program::malicious_self_program_id());
|
||||
self.insert_program(Program::malicious_caller_program_id());
|
||||
self.insert_program(Program::time_locked_transfer());
|
||||
self.insert_program(Program::pinata_cooldown());
|
||||
self
|
||||
@ -478,6 +482,28 @@ pub mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Flash Swap types (mirrors of guest types for host-side serialisation) ──
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct CallbackInstruction {
|
||||
return_funds: bool,
|
||||
token_program_id: ProgramId,
|
||||
amount: u128,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
enum FlashSwapInstruction {
|
||||
Initiate {
|
||||
token_program_id: ProgramId,
|
||||
callback_program_id: ProgramId,
|
||||
amount_out: u128,
|
||||
callback_instruction_data: Vec<u32>,
|
||||
},
|
||||
InvariantCheck {
|
||||
min_vault_balance: u128,
|
||||
},
|
||||
}
|
||||
|
||||
fn transfer_transaction(
|
||||
from: AccountId,
|
||||
from_key: &PrivateKey,
|
||||
@ -497,6 +523,23 @@ pub mod tests {
|
||||
PublicTransaction::new(message, witness_set)
|
||||
}
|
||||
|
||||
fn build_flash_swap_tx(
|
||||
initiator: &Program,
|
||||
vault_id: AccountId,
|
||||
receiver_id: AccountId,
|
||||
instruction: FlashSwapInstruction,
|
||||
) -> PublicTransaction {
|
||||
let message = public_transaction::Message::try_new(
|
||||
initiator.id(),
|
||||
vec![vault_id, receiver_id],
|
||||
vec![], // no signers — vault is PDA-authorised
|
||||
instruction,
|
||||
)
|
||||
.unwrap();
|
||||
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
|
||||
PublicTransaction::new(message, witness_set)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_with_genesis() {
|
||||
let key1 = PrivateKey::try_new([1; 32]).unwrap();
|
||||
@ -3877,4 +3920,242 @@ pub mod tests {
|
||||
let state_from_bytes: V03State = borsh::from_slice(&bytes).unwrap();
|
||||
assert_eq!(state, state_from_bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flash_swap_successful() {
|
||||
let initiator = Program::flash_swap_initiator();
|
||||
let callback = Program::flash_swap_callback();
|
||||
let token = Program::authenticated_transfer_program();
|
||||
|
||||
let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32])));
|
||||
let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32])));
|
||||
|
||||
let initial_balance: u128 = 1000;
|
||||
let amount_out: u128 = 100;
|
||||
|
||||
let vault_account = Account {
|
||||
program_owner: token.id(),
|
||||
balance: initial_balance,
|
||||
..Account::default()
|
||||
};
|
||||
let receiver_account = Account {
|
||||
program_owner: token.id(),
|
||||
balance: 0,
|
||||
..Account::default()
|
||||
};
|
||||
|
||||
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
|
||||
state.force_insert_account(vault_id, vault_account);
|
||||
state.force_insert_account(receiver_id, receiver_account);
|
||||
|
||||
// Callback instruction: return funds
|
||||
let cb_instruction = CallbackInstruction {
|
||||
return_funds: true,
|
||||
token_program_id: token.id(),
|
||||
amount: amount_out,
|
||||
};
|
||||
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
|
||||
|
||||
let instruction = FlashSwapInstruction::Initiate {
|
||||
token_program_id: token.id(),
|
||||
callback_program_id: callback.id(),
|
||||
amount_out,
|
||||
callback_instruction_data: cb_data,
|
||||
};
|
||||
|
||||
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);
|
||||
let result = state.transition_from_public_transaction(&tx, 1, 0);
|
||||
assert!(result.is_ok(), "flash swap should succeed: {result:?}");
|
||||
|
||||
// Vault balance restored, receiver back to 0
|
||||
assert_eq!(state.get_account_by_id(vault_id).balance, initial_balance);
|
||||
assert_eq!(state.get_account_by_id(receiver_id).balance, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flash_swap_callback_keeps_funds_rollback() {
|
||||
let initiator = Program::flash_swap_initiator();
|
||||
let callback = Program::flash_swap_callback();
|
||||
let token = Program::authenticated_transfer_program();
|
||||
|
||||
let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32])));
|
||||
let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32])));
|
||||
|
||||
let initial_balance: u128 = 1000;
|
||||
let amount_out: u128 = 100;
|
||||
|
||||
let vault_account = Account {
|
||||
program_owner: token.id(),
|
||||
balance: initial_balance,
|
||||
..Account::default()
|
||||
};
|
||||
let receiver_account = Account {
|
||||
program_owner: token.id(),
|
||||
balance: 0,
|
||||
..Account::default()
|
||||
};
|
||||
|
||||
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
|
||||
state.force_insert_account(vault_id, vault_account);
|
||||
state.force_insert_account(receiver_id, receiver_account);
|
||||
|
||||
// Callback instruction: do NOT return funds
|
||||
let cb_instruction = CallbackInstruction {
|
||||
return_funds: false,
|
||||
token_program_id: token.id(),
|
||||
amount: amount_out,
|
||||
};
|
||||
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
|
||||
|
||||
let instruction = FlashSwapInstruction::Initiate {
|
||||
token_program_id: token.id(),
|
||||
callback_program_id: callback.id(),
|
||||
amount_out,
|
||||
callback_instruction_data: cb_data,
|
||||
};
|
||||
|
||||
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);
|
||||
let result = state.transition_from_public_transaction(&tx, 1, 0);
|
||||
|
||||
// Invariant check fails → entire tx rolls back
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"flash swap should fail when callback keeps funds"
|
||||
);
|
||||
|
||||
// State unchanged (rollback)
|
||||
assert_eq!(state.get_account_by_id(vault_id).balance, initial_balance);
|
||||
assert_eq!(state.get_account_by_id(receiver_id).balance, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flash_swap_self_call_targets_correct_program() {
|
||||
// Zero-amount flash swap: the invariant self-call still runs and succeeds
|
||||
// because vault balance doesn't decrease.
|
||||
let initiator = Program::flash_swap_initiator();
|
||||
let callback = Program::flash_swap_callback();
|
||||
let token = Program::authenticated_transfer_program();
|
||||
|
||||
let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32])));
|
||||
let receiver_id = AccountId::from((&callback.id(), &PdaSeed::new([1_u8; 32])));
|
||||
|
||||
let initial_balance: u128 = 1000;
|
||||
|
||||
let vault_account = Account {
|
||||
program_owner: token.id(),
|
||||
balance: initial_balance,
|
||||
..Account::default()
|
||||
};
|
||||
let receiver_account = Account {
|
||||
program_owner: token.id(),
|
||||
balance: 0,
|
||||
..Account::default()
|
||||
};
|
||||
|
||||
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
|
||||
state.force_insert_account(vault_id, vault_account);
|
||||
state.force_insert_account(receiver_id, receiver_account);
|
||||
|
||||
let cb_instruction = CallbackInstruction {
|
||||
return_funds: true,
|
||||
token_program_id: token.id(),
|
||||
amount: 0,
|
||||
};
|
||||
let cb_data = Program::serialize_instruction(cb_instruction).unwrap();
|
||||
|
||||
let instruction = FlashSwapInstruction::Initiate {
|
||||
token_program_id: token.id(),
|
||||
callback_program_id: callback.id(),
|
||||
amount_out: 0,
|
||||
callback_instruction_data: cb_data,
|
||||
};
|
||||
|
||||
let tx = build_flash_swap_tx(&initiator, vault_id, receiver_id, instruction);
|
||||
let result = state.transition_from_public_transaction(&tx, 1, 0);
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"zero-amount flash swap should succeed: {result:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flash_swap_standalone_invariant_check_rejected() {
|
||||
// Calling InvariantCheck directly (not as a chained self-call) should fail
|
||||
// because caller_program_id will be None.
|
||||
let initiator = Program::flash_swap_initiator();
|
||||
let token = Program::authenticated_transfer_program();
|
||||
|
||||
let vault_id = AccountId::from((&initiator.id(), &PdaSeed::new([0_u8; 32])));
|
||||
|
||||
let vault_account = Account {
|
||||
program_owner: token.id(),
|
||||
balance: 1000,
|
||||
..Account::default()
|
||||
};
|
||||
|
||||
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
|
||||
state.force_insert_account(vault_id, vault_account);
|
||||
|
||||
let instruction = FlashSwapInstruction::InvariantCheck {
|
||||
min_vault_balance: 1000,
|
||||
};
|
||||
|
||||
let message = public_transaction::Message::try_new(
|
||||
initiator.id(),
|
||||
vec![vault_id],
|
||||
vec![],
|
||||
instruction,
|
||||
)
|
||||
.unwrap();
|
||||
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
|
||||
let tx = PublicTransaction::new(message, witness_set);
|
||||
|
||||
let result = state.transition_from_public_transaction(&tx, 1, 0);
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"standalone InvariantCheck should be rejected (caller_program_id is None)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn malicious_self_program_id_rejected_in_public_execution() {
|
||||
let program = Program::malicious_self_program_id();
|
||||
let acc_id = AccountId::new([99; 32]);
|
||||
let account = Account::default();
|
||||
|
||||
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
|
||||
state.force_insert_account(acc_id, account);
|
||||
|
||||
let message =
|
||||
public_transaction::Message::try_new(program.id(), vec![acc_id], vec![], ()).unwrap();
|
||||
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
|
||||
let tx = PublicTransaction::new(message, witness_set);
|
||||
|
||||
let result = state.transition_from_public_transaction(&tx, 1, 0);
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"program with wrong self_program_id in output should be rejected"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn malicious_caller_program_id_rejected_in_public_execution() {
|
||||
let program = Program::malicious_caller_program_id();
|
||||
let acc_id = AccountId::new([99; 32]);
|
||||
let account = Account::default();
|
||||
|
||||
let mut state = V03State::new_with_genesis_accounts(&[], &[], 0).with_test_programs();
|
||||
state.force_insert_account(acc_id, account);
|
||||
|
||||
let message =
|
||||
public_transaction::Message::try_new(program.id(), vec![acc_id], vec![], ()).unwrap();
|
||||
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
|
||||
let tx = PublicTransaction::new(message, witness_set);
|
||||
|
||||
let result = state.transition_from_public_transaction(&tx, 1, 0);
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"program with spoofed caller_program_id in output should be rejected"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -118,8 +118,11 @@ impl ValidatedStateDiff {
|
||||
"Program {:?} pre_states: {:?}, instruction_data: {:?}",
|
||||
chained_call.program_id, chained_call.pre_states, chained_call.instruction_data
|
||||
);
|
||||
let mut program_output =
|
||||
program.execute(&chained_call.pre_states, &chained_call.instruction_data)?;
|
||||
let mut program_output = program.execute(
|
||||
caller_program_id,
|
||||
&chained_call.pre_states,
|
||||
&chained_call.instruction_data,
|
||||
)?;
|
||||
debug!(
|
||||
"Program {:?} output: {:?}",
|
||||
chained_call.program_id, program_output
|
||||
@ -159,6 +162,12 @@ impl ValidatedStateDiff {
|
||||
NssaError::InvalidProgramBehavior
|
||||
);
|
||||
|
||||
// Verify that the program output's caller_program_id matches the actual caller.
|
||||
ensure!(
|
||||
program_output.caller_program_id == caller_program_id,
|
||||
NssaError::InvalidProgramBehavior
|
||||
);
|
||||
|
||||
// Verify execution corresponds to a well-behaved program.
|
||||
// See the # Programs section for the definition of the `validate_execution` method.
|
||||
ensure!(
|
||||
|
||||
@ -15,6 +15,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction,
|
||||
},
|
||||
@ -155,6 +156,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
pre_states_clone,
|
||||
post_states,
|
||||
|
||||
@ -5,6 +5,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction,
|
||||
},
|
||||
@ -59,6 +60,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
pre_states_clone,
|
||||
post_states,
|
||||
|
||||
@ -68,6 +68,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: balance_to_move,
|
||||
},
|
||||
@ -85,5 +86,12 @@ fn main() {
|
||||
_ => panic!("invalid params"),
|
||||
};
|
||||
|
||||
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write();
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
pre_states,
|
||||
post_states,
|
||||
)
|
||||
.write();
|
||||
}
|
||||
|
||||
@ -40,6 +40,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: timestamp,
|
||||
},
|
||||
@ -84,6 +85,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre_01, pre_10, pre_50],
|
||||
vec![post_01, post_10, post_50],
|
||||
|
||||
@ -47,6 +47,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: solution,
|
||||
},
|
||||
@ -81,6 +82,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pinata, winner],
|
||||
vec![
|
||||
|
||||
@ -53,6 +53,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: solution,
|
||||
},
|
||||
@ -99,6 +100,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![
|
||||
pinata_definition,
|
||||
|
||||
@ -114,6 +114,15 @@ impl ExecutionState {
|
||||
"Program output self_program_id does not match chained call program_id"
|
||||
);
|
||||
|
||||
// Verify that the program output's caller_program_id matches the actual caller.
|
||||
// This prevents a malicious user from privately executing an internal function
|
||||
// by spoofing caller_program_id (e.g. passing caller_program_id = self_program_id
|
||||
// to bypass access control checks).
|
||||
assert_eq!(
|
||||
program_output.caller_program_id, caller_program_id,
|
||||
"Program output caller_program_id does not match actual caller"
|
||||
);
|
||||
|
||||
// Check that the program is well behaved.
|
||||
// See the # Programs section for the definition of the `validate_execution` method.
|
||||
let execution_valid = validate_execution(
|
||||
|
||||
@ -13,6 +13,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction,
|
||||
},
|
||||
@ -84,6 +85,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
pre_states_clone,
|
||||
post_states,
|
||||
|
||||
@ -12,3 +12,4 @@ nssa_core.workspace = true
|
||||
clock_core.workspace = true
|
||||
|
||||
risc0-zkvm.workspace = true
|
||||
serde = { workspace = true, default-features = false }
|
||||
|
||||
@ -6,6 +6,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: balance_to_burn,
|
||||
},
|
||||
@ -22,6 +23,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![AccountPostState::new(account_post)],
|
||||
|
||||
@ -14,6 +14,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (balance, auth_transfer_id, num_chain_calls, pda_seed),
|
||||
},
|
||||
@ -57,6 +58,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![sender_pre.clone(), recipient_pre.clone()],
|
||||
vec![
|
||||
|
||||
@ -7,6 +7,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (data_opt, should_claim),
|
||||
},
|
||||
@ -36,6 +37,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![post_state],
|
||||
|
||||
@ -6,6 +6,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (),
|
||||
},
|
||||
@ -20,6 +21,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![account_post],
|
||||
|
||||
@ -15,6 +15,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (clock_program_id, timestamp),
|
||||
},
|
||||
@ -33,7 +34,13 @@ fn main() {
|
||||
pda_seeds: vec![],
|
||||
};
|
||||
|
||||
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states)
|
||||
.with_chained_calls(vec![chained_call])
|
||||
.write();
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
pre_states,
|
||||
post_states,
|
||||
)
|
||||
.with_chained_calls(vec![chained_call])
|
||||
.write();
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: data,
|
||||
},
|
||||
@ -25,6 +26,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![AccountPostState::new_claimed(
|
||||
|
||||
@ -9,6 +9,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
..
|
||||
},
|
||||
@ -23,6 +24,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![
|
||||
|
||||
94
test_program_methods/guest/src/bin/flash_swap_callback.rs
Normal file
94
test_program_methods/guest/src/bin/flash_swap_callback.rs
Normal file
@ -0,0 +1,94 @@
|
||||
//! Flash swap callback, the user logic step in the "prep → callback → assert" pattern.
|
||||
//!
|
||||
//! # Role
|
||||
//!
|
||||
//! This program is called as chained call 2 in the flash swap sequence:
|
||||
//! 1. Token transfer out (vault → receiver)
|
||||
//! 2. **This callback** (user logic)
|
||||
//! 3. Invariant check (assert vault balance restored)
|
||||
//!
|
||||
//! In a real flash swap, this would contain the user's arbitrage or other logic.
|
||||
//! In this test program, it is controlled by `return_funds`:
|
||||
//!
|
||||
//! - `return_funds = true`: emits a token transfer (receiver → vault) to return the funds. The
|
||||
//! invariant check will pass and the transaction will succeed.
|
||||
//!
|
||||
//! - `return_funds = false`: emits no transfers. Funds stay with the receiver. The invariant check
|
||||
//! will fail (vault balance < initial), causing full atomic rollback. This simulates a malicious
|
||||
//! or buggy callback that does not repay the flash loan.
|
||||
//!
|
||||
//! # Note on `caller_program_id`
|
||||
//!
|
||||
//! This program does not enforce any access control on `caller_program_id`.
|
||||
//! It is designed to be called by the flash swap initiator but could in principle be
|
||||
//! called by any program. In production, a callback would typically verify the caller
|
||||
//! if it needs to trust the context it is called from.
|
||||
|
||||
use nssa_core::program::{
|
||||
AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput,
|
||||
read_nssa_inputs,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CallbackInstruction {
|
||||
/// If true, return the borrowed funds to the vault (happy path).
|
||||
/// If false, keep the funds (simulates a malicious callback, triggers rollback).
|
||||
pub return_funds: bool,
|
||||
pub token_program_id: ProgramId,
|
||||
pub amount: u128,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id, // not enforced in this callback
|
||||
pre_states,
|
||||
instruction,
|
||||
},
|
||||
instruction_words,
|
||||
) = read_nssa_inputs::<CallbackInstruction>();
|
||||
|
||||
// pre_states[0] = vault (after transfer out), pre_states[1] = receiver (after transfer out)
|
||||
let Ok([vault_pre, receiver_pre]) = <[_; 2]>::try_from(pre_states) else {
|
||||
panic!("Callback requires exactly 2 accounts: vault, receiver");
|
||||
};
|
||||
|
||||
let mut chained_calls = Vec::new();
|
||||
|
||||
if instruction.return_funds {
|
||||
// Happy path: return the borrowed funds via a token transfer (receiver → vault).
|
||||
// The receiver is a PDA of this callback program (seed = [1_u8; 32]).
|
||||
// Mark the receiver as authorized since it will be PDA-authorized in this chained call.
|
||||
let mut receiver_authorized = receiver_pre.clone();
|
||||
receiver_authorized.is_authorized = true;
|
||||
let transfer_instruction = risc0_zkvm::serde::to_vec(&instruction.amount)
|
||||
.expect("transfer instruction serialization");
|
||||
|
||||
chained_calls.push(ChainedCall {
|
||||
program_id: instruction.token_program_id,
|
||||
pre_states: vec![receiver_authorized, vault_pre.clone()],
|
||||
instruction_data: transfer_instruction,
|
||||
pda_seeds: vec![PdaSeed::new([1_u8; 32])],
|
||||
});
|
||||
}
|
||||
// Malicious path (return_funds = false): emit no chained calls.
|
||||
// The vault balance will not be restored, so the invariant check in the initiator
|
||||
// will panic, rolling back the entire transaction including the initial transfer out.
|
||||
|
||||
// The callback itself makes no direct state changes, accounts pass through unchanged.
|
||||
// All mutations go through the token program via chained calls.
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![vault_pre.clone(), receiver_pre.clone()],
|
||||
vec![
|
||||
AccountPostState::new(vault_pre.account),
|
||||
AccountPostState::new(receiver_pre.account),
|
||||
],
|
||||
)
|
||||
.with_chained_calls(chained_calls)
|
||||
.write();
|
||||
}
|
||||
216
test_program_methods/guest/src/bin/flash_swap_initiator.rs
Normal file
216
test_program_methods/guest/src/bin/flash_swap_initiator.rs
Normal file
@ -0,0 +1,216 @@
|
||||
//! Flash swap initiator, demonstrates the "prep → callback → assert" pattern using
|
||||
//! generalized multi tail-calls with `self_program_id` and `caller_program_id`.
|
||||
//!
|
||||
//! # Pattern
|
||||
//!
|
||||
//! A flash swap lets a program optimistically transfer tokens out, run arbitrary user
|
||||
//! logic (the callback), then assert that invariants hold after the callback. The entire
|
||||
//! sequence is a single atomic transaction: if any step fails, all state changes roll back.
|
||||
//!
|
||||
//! # How it works
|
||||
//!
|
||||
//! This program handles two instruction variants:
|
||||
//!
|
||||
//! - `Initiate` (external): the top-level entrypoint. Emits 3 chained calls:
|
||||
//! 1. Token transfer out (vault → receiver)
|
||||
//! 2. User callback (arbitrary logic, e.g. arbitrage)
|
||||
//! 3. Self-call to `InvariantCheck` (using `self_program_id` to reference itself)
|
||||
//!
|
||||
//! - `InvariantCheck` (internal): enforces that the vault balance was restored after the callback.
|
||||
//! Uses `caller_program_id == Some(self_program_id)` to prevent standalone calls (this is the
|
||||
//! visibility enforcement mechanism).
|
||||
//!
|
||||
//! # What this demonstrates
|
||||
//!
|
||||
//! - `self_program_id`: enables a program to chain back to itself (step 3 above)
|
||||
//! - `caller_program_id`: enables a program to restrict which callers can invoke an instruction
|
||||
//! - Computed intermediate states: the initiator computes expected intermediate account states from
|
||||
//! the `pre_states` and amount, keeping the instruction minimal.
|
||||
//! - Atomic rollback: if the callback doesn't return funds, the invariant check fails, and all
|
||||
//! state changes from steps 1 and 2 are rolled back automatically.
|
||||
//!
|
||||
//! # Tests
|
||||
//!
|
||||
//! See `nssa/src/state.rs` for integration tests:
|
||||
//! - `flash_swap_successful`: full round-trip, funds returned, state unchanged
|
||||
//! - `flash_swap_callback_keeps_funds_rollback`: callback keeps funds, full rollback
|
||||
//! - `flash_swap_self_call_targets_correct_program`: zero-amount self-call isolation test
|
||||
//! - `flash_swap_standalone_invariant_check_rejected`: `caller_program_id` access control
|
||||
|
||||
use nssa_core::program::{
|
||||
AccountPostState, ChainedCall, PdaSeed, ProgramId, ProgramInput, ProgramOutput,
|
||||
read_nssa_inputs,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum FlashSwapInstruction {
|
||||
/// External entrypoint: initiate a flash swap.
|
||||
///
|
||||
/// Emits 3 chained calls:
|
||||
/// 1. Token transfer (vault → receiver, `amount_out`)
|
||||
/// 2. Callback (user logic, e.g. arbitrage)
|
||||
/// 3. Self-call `InvariantCheck` (verify vault balance did not decrease)
|
||||
///
|
||||
/// Intermediate account states are computed inside the program from `pre_states` and
|
||||
/// `amount_out`.
|
||||
Initiate {
|
||||
token_program_id: ProgramId,
|
||||
callback_program_id: ProgramId,
|
||||
amount_out: u128,
|
||||
callback_instruction_data: Vec<u32>,
|
||||
},
|
||||
/// Internal: verify the vault invariant holds after callback execution.
|
||||
///
|
||||
/// Access control: only callable as a chained call from this program itself.
|
||||
/// This is enforced by checking `caller_program_id == Some(self_program_id)`.
|
||||
/// Any attempt to call this instruction as a standalone top-level transaction
|
||||
/// will be rejected because `caller_program_id` will be `None`.
|
||||
InvariantCheck { min_vault_balance: u128 },
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction,
|
||||
},
|
||||
instruction_words,
|
||||
) = read_nssa_inputs::<FlashSwapInstruction>();
|
||||
|
||||
match instruction {
|
||||
FlashSwapInstruction::Initiate {
|
||||
token_program_id,
|
||||
callback_program_id,
|
||||
amount_out,
|
||||
callback_instruction_data,
|
||||
} => {
|
||||
let Ok([vault_pre, receiver_pre]) = <[_; 2]>::try_from(pre_states) else {
|
||||
panic!("Initiate requires exactly 2 accounts: vault, receiver");
|
||||
};
|
||||
|
||||
// Capture initial vault balance, the invariant check will verify it is restored.
|
||||
let min_vault_balance = vault_pre.account.balance;
|
||||
|
||||
// Compute intermediate account states from pre_states and amount_out.
|
||||
let mut vault_after_transfer = vault_pre.clone();
|
||||
vault_after_transfer.account.balance = vault_pre
|
||||
.account
|
||||
.balance
|
||||
.checked_sub(amount_out)
|
||||
.expect("vault has insufficient balance for flash swap");
|
||||
|
||||
let mut receiver_after_transfer = receiver_pre.clone();
|
||||
receiver_after_transfer.account.balance = receiver_pre
|
||||
.account
|
||||
.balance
|
||||
.checked_add(amount_out)
|
||||
.expect("receiver balance overflow");
|
||||
|
||||
let mut vault_after_callback = vault_after_transfer.clone();
|
||||
vault_after_callback.account.balance = vault_after_transfer
|
||||
.account
|
||||
.balance
|
||||
.checked_add(amount_out)
|
||||
.expect("vault balance overflow after callback");
|
||||
|
||||
// Chained call 1: Token transfer (vault → receiver).
|
||||
// The vault is a PDA of this initiator program (seed = [0_u8; 32]), so we provide
|
||||
// the PDA seed to authorize the token program to debit the vault on our behalf.
|
||||
// Mark the vault as authorized since it will be PDA-authorized in this chained call.
|
||||
let mut vault_authorized = vault_pre.clone();
|
||||
vault_authorized.is_authorized = true;
|
||||
let transfer_instruction =
|
||||
risc0_zkvm::serde::to_vec(&amount_out).expect("transfer instruction serialization");
|
||||
let call_1 = ChainedCall {
|
||||
program_id: token_program_id,
|
||||
pre_states: vec![vault_authorized, receiver_pre.clone()],
|
||||
instruction_data: transfer_instruction,
|
||||
pda_seeds: vec![PdaSeed::new([0_u8; 32])],
|
||||
};
|
||||
|
||||
// Chained call 2: User callback.
|
||||
// Receives the post-transfer states as its pre_states. The callback may run
|
||||
// arbitrary logic (arbitrage, etc.) and is expected to return funds to the vault.
|
||||
let call_2 = ChainedCall {
|
||||
program_id: callback_program_id,
|
||||
pre_states: vec![vault_after_transfer, receiver_after_transfer],
|
||||
instruction_data: callback_instruction_data,
|
||||
pda_seeds: vec![],
|
||||
};
|
||||
|
||||
// Chained call 3: Self-call to enforce the invariant.
|
||||
// Uses `self_program_id` to reference this program, the key feature that enables
|
||||
// the "prep → callback → assert" pattern without a separate checker program.
|
||||
// If the callback did not return funds, vault_after_callback.balance <
|
||||
// min_vault_balance and this call will panic, rolling back the entire
|
||||
// transaction.
|
||||
let invariant_instruction =
|
||||
risc0_zkvm::serde::to_vec(&FlashSwapInstruction::InvariantCheck {
|
||||
min_vault_balance,
|
||||
})
|
||||
.expect("invariant instruction serialization");
|
||||
let call_3 = ChainedCall {
|
||||
program_id: self_program_id, // self-referential chained call
|
||||
pre_states: vec![vault_after_callback],
|
||||
instruction_data: invariant_instruction,
|
||||
pda_seeds: vec![],
|
||||
};
|
||||
|
||||
// The initiator itself makes no direct state changes.
|
||||
// All mutations happen inside the chained calls (token transfers).
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![vault_pre.clone(), receiver_pre.clone()],
|
||||
vec![
|
||||
AccountPostState::new(vault_pre.account),
|
||||
AccountPostState::new(receiver_pre.account),
|
||||
],
|
||||
)
|
||||
.with_chained_calls(vec![call_1, call_2, call_3])
|
||||
.write();
|
||||
}
|
||||
|
||||
FlashSwapInstruction::InvariantCheck { min_vault_balance } => {
|
||||
// Visibility enforcement: `InvariantCheck` is an internal instruction.
|
||||
// It must only be called as a chained call from this program itself (via `Initiate`).
|
||||
// When called as a top-level transaction, `caller_program_id` is `None` → panics.
|
||||
// When called as a chained call from `Initiate`, `caller_program_id` is
|
||||
// `Some(self_program_id)` → passes.
|
||||
assert_eq!(
|
||||
caller_program_id,
|
||||
Some(self_program_id),
|
||||
"InvariantCheck is an internal instruction: must be called by flash_swap_initiator \
|
||||
via a chained call",
|
||||
);
|
||||
|
||||
let Ok([vault]) = <[_; 1]>::try_from(pre_states) else {
|
||||
panic!("InvariantCheck requires exactly 1 account: vault");
|
||||
};
|
||||
|
||||
// The core invariant: vault balance must not have decreased.
|
||||
// If the callback returned funds, this passes. If not, this panics and
|
||||
// the entire transaction (including the prior token transfer) rolls back.
|
||||
assert!(
|
||||
vault.account.balance >= min_vault_balance,
|
||||
"Flash swap invariant violated: vault balance {} < minimum {}",
|
||||
vault.account.balance,
|
||||
min_vault_balance
|
||||
);
|
||||
|
||||
// Pass-through: no state changes in the invariant check step.
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![vault.clone()],
|
||||
vec![AccountPostState::new(vault.account)],
|
||||
)
|
||||
.write();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -15,6 +15,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (balance, transfer_program_id),
|
||||
},
|
||||
@ -42,6 +43,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![sender.clone(), receiver.clone()],
|
||||
vec![
|
||||
|
||||
@ -0,0 +1,34 @@
|
||||
use nssa_core::program::{
|
||||
AccountPostState, DEFAULT_PROGRAM_ID, ProgramInput, ProgramOutput, read_nssa_inputs,
|
||||
};
|
||||
|
||||
type Instruction = ();
|
||||
|
||||
fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id: _, // ignore the actual caller
|
||||
pre_states,
|
||||
instruction: (),
|
||||
},
|
||||
instruction_words,
|
||||
) = read_nssa_inputs::<Instruction>();
|
||||
|
||||
let post_states = pre_states
|
||||
.iter()
|
||||
.map(|a| AccountPostState::new(a.account.clone()))
|
||||
.collect();
|
||||
|
||||
// Deliberately output wrong caller_program_id.
|
||||
// A real caller_program_id is None for a top-level call, so we spoof Some(DEFAULT_PROGRAM_ID)
|
||||
// to simulate a program claiming it was invoked by another program when it was not.
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
Some(DEFAULT_PROGRAM_ID), // WRONG: should be None for a top-level call
|
||||
instruction_words,
|
||||
pre_states,
|
||||
post_states,
|
||||
)
|
||||
.write();
|
||||
}
|
||||
@ -0,0 +1,32 @@
|
||||
use nssa_core::program::{
|
||||
AccountPostState, DEFAULT_PROGRAM_ID, ProgramInput, ProgramOutput, read_nssa_inputs,
|
||||
};
|
||||
|
||||
type Instruction = ();
|
||||
|
||||
fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id: _, // ignore the correct ID
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (),
|
||||
},
|
||||
instruction_words,
|
||||
) = read_nssa_inputs::<Instruction>();
|
||||
|
||||
let post_states = pre_states
|
||||
.iter()
|
||||
.map(|a| AccountPostState::new(a.account.clone()))
|
||||
.collect();
|
||||
|
||||
// Deliberately output wrong self_program_id
|
||||
ProgramOutput::new(
|
||||
DEFAULT_PROGRAM_ID, // WRONG: should be self_program_id
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
pre_states,
|
||||
post_states,
|
||||
)
|
||||
.write();
|
||||
}
|
||||
@ -6,6 +6,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
..
|
||||
},
|
||||
@ -25,6 +26,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![AccountPostState::new(account_post)],
|
||||
|
||||
@ -6,6 +6,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
..
|
||||
},
|
||||
@ -20,6 +21,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre1, pre2],
|
||||
vec![AccountPostState::new(account_pre1)],
|
||||
|
||||
@ -65,6 +65,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: balance_to_move,
|
||||
},
|
||||
@ -81,5 +82,12 @@ fn main() {
|
||||
}
|
||||
_ => panic!("invalid params"),
|
||||
};
|
||||
ProgramOutput::new(self_program_id, instruction_data, pre_states, post_states).write();
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_data,
|
||||
pre_states,
|
||||
post_states,
|
||||
)
|
||||
.write();
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
..
|
||||
},
|
||||
@ -22,6 +23,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![AccountPostState::new(account_post)],
|
||||
|
||||
@ -6,6 +6,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
..
|
||||
},
|
||||
@ -16,5 +17,12 @@ fn main() {
|
||||
.iter()
|
||||
.map(|account| AccountPostState::new(account.account.clone()))
|
||||
.collect();
|
||||
ProgramOutput::new(self_program_id, instruction_words, pre_states, post_states).write();
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
pre_states,
|
||||
post_states,
|
||||
)
|
||||
.write();
|
||||
}
|
||||
|
||||
@ -49,6 +49,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (),
|
||||
},
|
||||
@ -102,6 +103,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pinata, winner, clock_pre],
|
||||
vec![
|
||||
|
||||
@ -6,6 +6,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
..
|
||||
},
|
||||
@ -22,6 +23,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![AccountPostState::new(account_post)],
|
||||
|
||||
@ -6,6 +6,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: balance,
|
||||
},
|
||||
@ -29,6 +30,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![sender_pre, receiver_pre],
|
||||
vec![
|
||||
|
||||
@ -19,6 +19,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (amount, deadline),
|
||||
},
|
||||
@ -58,6 +59,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![sender_pre, receiver_pre, clock_pre],
|
||||
vec![
|
||||
|
||||
@ -9,6 +9,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (block_validity_window, timestamp_validity_window),
|
||||
},
|
||||
@ -23,6 +24,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![AccountPostState::new(post)],
|
||||
|
||||
@ -17,6 +17,7 @@ fn main() {
|
||||
let (
|
||||
ProgramInput {
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
pre_states,
|
||||
instruction: (block_validity_window, chained_program_id, chained_block_validity_window),
|
||||
},
|
||||
@ -40,6 +41,7 @@ fn main() {
|
||||
|
||||
ProgramOutput::new(
|
||||
self_program_id,
|
||||
caller_program_id,
|
||||
instruction_words,
|
||||
vec![pre],
|
||||
vec![AccountPostState::new(post)],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user