add public validity window checks

This commit is contained in:
Sergio Chouhy 2026-03-19 15:03:45 -03:00
parent 895dd942cf
commit 7bbd2dd5d7
25 changed files with 113 additions and 31 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -152,7 +152,7 @@ impl AccountPostState {
}
pub type BlockId = u64;
pub type ValidityRange = (Option<BlockId> , Option<BlockId>);
pub type ValidityWindow = (Option<BlockId>, Option<BlockId>);
#[derive(Serialize, Deserialize, Clone)]
#[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))]
@ -163,24 +163,48 @@ pub struct ProgramOutput {
pub pre_states: Vec<AccountWithMetadata>,
pub post_states: Vec<AccountPostState>,
pub chained_calls: Vec<ChainedCall>,
pub validity_range: ValidityRange,
pub validity_window: ValidityWindow,
}
impl ProgramOutput {
#[must_use]
pub const fn new(
instruction_data: InstructionData,
pre_states: Vec<AccountWithMetadata>,
post_states: Vec<AccountPostState>,
) -> Self {
Self {
instruction_data,
pre_states,
post_states,
chained_calls: Vec::new(),
validity_window: (None, None),
}
}
pub fn write(self) {
env::commit(&self);
}
#[must_use]
pub fn with_chained_calls(mut self, chained_calls: Vec<ChainedCall>) -> Self {
self.chained_calls = chained_calls;
self
}
#[must_use]
pub const fn valid_from_id(mut self, id: BlockId) -> Self {
self.validity_range.0 = Some(id);
self.validity_window.0 = Some(id);
self
}
#[must_use]
pub const fn valid_until_id(mut self, id: BlockId) -> Self {
self.validity_range.1 = Some(id);
self.validity_window.1 = Some(id);
self
}
}
/// Representation of a number as `lo + hi * 2^128`.
#[derive(PartialEq, Eq)]
struct WrappedBalanceSum {
@ -243,14 +267,7 @@ pub fn write_nssa_outputs(
pre_states: Vec<AccountWithMetadata>,
post_states: Vec<AccountPostState>,
) {
let output = ProgramOutput {
instruction_data,
pre_states,
post_states,
chained_calls: Vec::new(),
validity_range: (None, None)
};
env::commit(&output);
ProgramOutput::new(instruction_data, pre_states, post_states).write();
}
pub fn write_nssa_outputs_with_chained_call(
@ -259,14 +276,9 @@ pub fn write_nssa_outputs_with_chained_call(
post_states: Vec<AccountPostState>,
chained_calls: Vec<ChainedCall>,
) {
let output = ProgramOutput {
instruction_data,
pre_states,
post_states,
chained_calls,
validity_range: (None, None)
};
env::commit(&output);
ProgramOutput::new(instruction_data, pre_states, post_states)
.with_chained_calls(chained_calls)
.write();
}
/// Validates well-behaved program execution.

View File

@ -69,6 +69,9 @@ pub enum NssaError {
#[error("Max account nonce reached")]
MaxAccountNonceReached,
#[error("Execution outside of the validity window")]
OutOfValidityWindow,
}
#[cfg(test)]

View File

@ -284,6 +284,14 @@ mod tests {
// `program_methods`
Self::new(MODIFIED_TRANSFER_ELF.to_vec()).unwrap()
}
#[must_use]
pub fn validity_window() -> Self {
use test_program_methods::VALIDITY_WINDOW_ELF;
// This unwrap won't panic since the `VALIDITY_WINDOW_ELF` comes from risc0 build of
// `program_methods`
Self::new(VALIDITY_WINDOW_ELF.to_vec()).unwrap()
}
}
#[test]

View File

@ -4,7 +4,7 @@ use borsh::{BorshDeserialize, BorshSerialize};
use log::debug;
use nssa_core::{
account::{Account, AccountId, AccountWithMetadata},
program::{ChainedCall, DEFAULT_PROGRAM_ID, validate_execution},
program::{BlockId, ChainedCall, DEFAULT_PROGRAM_ID, validate_execution},
};
use sha2::{Digest as _, digest::FixedOutput as _};
@ -70,6 +70,7 @@ impl PublicTransaction {
pub(crate) fn validate_and_produce_public_state_diff(
&self,
state: &V02State,
block_id: BlockId,
) -> Result<HashMap<AccountId, Account>, NssaError> {
let message = self.message();
let witness_set = self.witness_set();
@ -190,6 +191,14 @@ impl PublicTransaction {
NssaError::InvalidProgramBehavior
);
// Verify validity window
if let Some(from_id) = program_output.validity_window.0 {
ensure!(from_id <= block_id, NssaError::OutOfValidityWindow);
}
if let Some(until_id) = program_output.validity_window.1 {
ensure!(until_id >= block_id, NssaError::OutOfValidityWindow);
}
for post in program_output
.post_states
.iter_mut()
@ -359,7 +368,7 @@ pub mod tests {
let witness_set = WitnessSet::for_message(&message, &[&key1, &key1]);
let tx = PublicTransaction::new(message, witness_set);
let result = tx.validate_and_produce_public_state_diff(&state);
let result = tx.validate_and_produce_public_state_diff(&state, 1);
assert!(matches!(result, Err(NssaError::InvalidInput(_))));
}
@ -379,7 +388,7 @@ pub mod tests {
let witness_set = WitnessSet::for_message(&message, &[&key1, &key2]);
let tx = PublicTransaction::new(message, witness_set);
let result = tx.validate_and_produce_public_state_diff(&state);
let result = tx.validate_and_produce_public_state_diff(&state, 1);
assert!(matches!(result, Err(NssaError::InvalidInput(_))));
}
@ -400,7 +409,7 @@ pub mod tests {
let mut witness_set = WitnessSet::for_message(&message, &[&key1, &key2]);
witness_set.signatures_and_public_keys[0].0 = Signature::new_for_tests([1; 64]);
let tx = PublicTransaction::new(message, witness_set);
let result = tx.validate_and_produce_public_state_diff(&state);
let result = tx.validate_and_produce_public_state_diff(&state, 1);
assert!(matches!(result, Err(NssaError::InvalidInput(_))));
}
@ -420,7 +429,7 @@ pub mod tests {
let witness_set = WitnessSet::for_message(&message, &[&key1, &key2]);
let tx = PublicTransaction::new(message, witness_set);
let result = tx.validate_and_produce_public_state_diff(&state);
let result = tx.validate_and_produce_public_state_diff(&state, 1);
assert!(matches!(result, Err(NssaError::InvalidInput(_))));
}
@ -436,7 +445,7 @@ pub mod tests {
let witness_set = WitnessSet::for_message(&message, &[&key1, &key2]);
let tx = PublicTransaction::new(message, witness_set);
let result = tx.validate_and_produce_public_state_diff(&state);
let result = tx.validate_and_produce_public_state_diff(&state, 1);
assert!(matches!(result, Err(NssaError::InvalidInput(_))));
}
}

View File

@ -2,7 +2,9 @@ use std::collections::{BTreeSet, HashMap, HashSet};
use borsh::{BorshDeserialize, BorshSerialize};
use nssa_core::{
account::{Account, AccountId, Nonce}, program::{BlockId, ProgramId}, Commitment, CommitmentSetDigest, MembershipProof, Nullifier, DUMMY_COMMITMENT
Commitment, CommitmentSetDigest, DUMMY_COMMITMENT, MembershipProof, Nullifier,
account::{Account, AccountId, Nonce},
program::{BlockId, ProgramId},
};
use crate::{
@ -155,9 +157,9 @@ impl V02State {
pub fn transition_from_public_transaction(
&mut self,
tx: &PublicTransaction,
_block_id: BlockId,
block_id: BlockId,
) -> Result<(), NssaError> {
let state_diff = tx.validate_and_produce_public_state_diff(self)?;
let state_diff = tx.validate_and_produce_public_state_diff(self, block_id)?;
#[expect(
clippy::iter_over_hash_type,
@ -338,7 +340,7 @@ pub mod tests {
Commitment, Nullifier, NullifierPublicKey, NullifierSecretKey, SharedSecretKey,
account::{Account, AccountId, AccountWithMetadata, Nonce, data::Data},
encryption::{EphemeralPublicKey, Scalar, ViewingPublicKey},
program::{PdaSeed, ProgramId},
program::{BlockId, PdaSeed, ProgramId, ValidityWindow},
};
use crate::{
@ -373,6 +375,7 @@ pub mod tests {
self.insert_program(Program::amm());
self.insert_program(Program::claimer());
self.insert_program(Program::changer_claimer());
self.insert_program(Program::validity_window());
self
}
@ -2996,6 +2999,53 @@ pub mod tests {
assert!(matches!(result, Err(NssaError::CircuitProvingError(_))));
}
#[test_case::test_case((Some(1), Some(3)), 3; "at upper bound")]
#[test_case::test_case((Some(1), Some(3)), 2; "inside range")]
#[test_case::test_case((Some(1), Some(3)), 0; "below range")]
#[test_case::test_case((Some(1), Some(3)), 1; "at lower bound")]
#[test_case::test_case((Some(1), Some(3)), 4; "above range")]
#[test_case::test_case((Some(1), None), 1; "lower bound only - at bound")]
#[test_case::test_case((Some(1), None), 10; "lower bound only - above")]
#[test_case::test_case((Some(1), None), 0; "lower bound only - below")]
#[test_case::test_case((None, Some(3)), 3; "upper bound only - at bound")]
#[test_case::test_case((None, Some(3)), 0; "upper bound only - below")]
#[test_case::test_case((None, Some(3)), 4; "upper bound only - above")]
#[test_case::test_case((None, None), 0; "no bounds - always valid")]
#[test_case::test_case((None, None), 100; "no bounds - always valid 2")]
fn validity_window_works(validity_window: ValidityWindow, block_id: BlockId) {
let validity_window_program = Program::validity_window();
let account_keys = test_public_account_keys_1();
let pre = AccountWithMetadata::new(Account::default(), false, account_keys.account_id());
let mut state =
V02State::new_with_genesis_accounts(&[], &[]).with_test_programs();
let tx = {
let account_ids = vec![pre.account_id];
let nonces = vec![];
let program_id = validity_window_program.id();
let message = public_transaction::Message::try_new(
program_id,
account_ids,
nonces,
validity_window,
)
.unwrap();
let witness_set = public_transaction::WitnessSet::for_message(&message, &[]);
PublicTransaction::new(message, witness_set)
};
let result = state.transition_from_public_transaction(&tx, block_id);
let is_inside_validity_window = match (validity_window.0, validity_window.1) {
(Some(s), Some(e)) => s <= block_id && block_id <= e,
(Some(s), None) => s <= block_id,
(None, Some(e)) => block_id <= e,
(None, None) => true,
};
if is_inside_validity_window {
assert!(result.is_ok());
} else {
assert!(matches!(result, Err(NssaError::OutOfValidityWindow)))
}
}
#[test]
fn state_serialization_roundtrip() {
let account_id_1 = AccountId::new([1; 32]);