From 325744044872bbd96e20036ca9b4f68f6c95591e Mon Sep 17 00:00:00 2001 From: Sergio Chouhy Date: Fri, 20 Mar 2026 13:16:52 -0300 Subject: [PATCH] enforce valid window construction --- indexer/service/protocol/src/convert.rs | 21 ++--- nssa/core/src/circuit_io.rs | 2 +- nssa/core/src/program.rs | 85 ++++++++++++++++--- .../privacy_preserving_transaction/message.rs | 2 +- .../transaction.rs | 17 ++-- nssa/src/public_transaction/transaction.rs | 12 +-- nssa/src/state.rs | 18 ++-- .../src/bin/privacy_preserving_circuit.rs | 10 ++- .../guest/src/bin/validity_window.rs | 15 ++-- 9 files changed, 119 insertions(+), 63 deletions(-) diff --git a/indexer/service/protocol/src/convert.rs b/indexer/service/protocol/src/convert.rs index 9e426bd8..ec85d7fb 100644 --- a/indexer/service/protocol/src/convert.rs +++ b/indexer/service/protocol/src/convert.rs @@ -302,13 +302,13 @@ impl From for PrivacyPre .into_iter() .map(|(n, d)| (n.into(), d.into())) .collect(), - validity_window: ValidityWindow(validity_window), + validity_window: ValidityWindow((validity_window.from(), validity_window.to())), } } } impl TryFrom for nssa::privacy_preserving_transaction::message::Message { - type Error = nssa_core::account::data::DataTooBigError; + type Error = nssa::error::NssaError; fn try_from(value: PrivacyPreservingMessage) -> Result { let PrivacyPreservingMessage { @@ -329,7 +329,8 @@ impl TryFrom for nssa::privacy_preserving_transaction: public_post_states: public_post_states .into_iter() .map(TryInto::try_into) - .collect::, _>>()?, + .collect::, _>>() + .map_err(|e| nssa::error::NssaError::InvalidInput(format!("{e}")))?, encrypted_private_post_states: encrypted_private_post_states .into_iter() .map(Into::into) @@ -339,7 +340,10 @@ impl TryFrom for nssa::privacy_preserving_transaction: .into_iter() .map(|(n, d)| (n.into(), d.into())) .collect(), - validity_window: validity_window.0, + validity_window: validity_window + .0 + .try_into() + .map_err(|e| nssa::error::NssaError::InvalidInput(format!("{e}")))?, }) } } @@ -483,14 +487,7 @@ impl TryFrom for nssa::PrivacyPreservingTransactio witness_set, } = value; - Ok(Self::new( - message - .try_into() - .map_err(|err: nssa_core::account::data::DataTooBigError| { - nssa::error::NssaError::InvalidInput(err.to_string()) - })?, - witness_set.try_into()?, - )) + Ok(Self::new(message.try_into()?, witness_set.try_into()?)) } } diff --git a/nssa/core/src/circuit_io.rs b/nssa/core/src/circuit_io.rs index 86d4abec..f9cd9239 100644 --- a/nssa/core/src/circuit_io.rs +++ b/nssa/core/src/circuit_io.rs @@ -102,7 +102,7 @@ mod tests { ), [0xab; 32], )], - validity_window: (Some(1), None), + validity_window: (Some(1), None).try_into().unwrap(), }; let bytes = output.to_bytes(); let output_from_slice: PrivacyPreservingCircuitOutput = from_slice(&bytes).unwrap(); diff --git a/nssa/core/src/program.rs b/nssa/core/src/program.rs index bb89da6e..5cd46432 100644 --- a/nssa/core/src/program.rs +++ b/nssa/core/src/program.rs @@ -1,5 +1,7 @@ use std::collections::HashSet; +#[cfg(feature = "host")] +use borsh::{BorshDeserialize, BorshSerialize}; use risc0_zkvm::{DeserializeOwned, guest::env, serde::Deserializer}; use serde::{Deserialize, Serialize}; @@ -152,7 +154,68 @@ impl AccountPostState { } pub type BlockId = u64; -pub type ValidityWindow = (Option, Option); + +#[derive(Serialize, Deserialize, Clone, Copy)] +#[cfg_attr( + any(feature = "host", test), + derive(Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize) +)] +pub struct ValidityWindow { + from: Option, + to: Option, +} + +impl ValidityWindow { + #[must_use] + pub const fn new_unbounded() -> Self { + Self { + from: None, + to: None, + } + } + + /// Valid for block IDs in the range [from, to), where `from` is included and `to` is excluded. + #[must_use] + pub fn is_valid_for_block_id(&self, id: BlockId) -> bool { + self.from.is_none_or(|start| id >= start) && self.to.is_none_or(|end| id < end) + } + + const fn check_window(&self) -> Result<(), InvalidWindow> { + if let (Some(from_id), Some(until_id)) = (self.from, self.to) + && from_id >= until_id + { + Err(InvalidWindow) + } else { + Ok(()) + } + } + + #[must_use] + pub const fn from(&self) -> Option { + self.from + } + + #[must_use] + pub const fn to(&self) -> Option { + self.to + } +} +impl TryFrom<(Option, Option)> for ValidityWindow { + type Error = InvalidWindow; + + fn try_from(value: (Option, Option)) -> Result { + let this = Self { + from: value.0, + to: value.1, + }; + this.check_window()?; + Ok(this) + } +} + +#[derive(Debug, thiserror::Error, Clone, Copy, PartialEq, Eq)] +#[error("Invalid window")] +pub struct InvalidWindow; #[derive(Serialize, Deserialize, Clone)] #[cfg_attr(any(feature = "host", test), derive(Debug, PartialEq, Eq))] @@ -166,6 +229,8 @@ pub struct ProgramOutput { /// The list of chained calls to other programs. pub chained_calls: Vec, /// The window where the program output is valid. + /// Valid for block IDs in the range [from, to), where `from` is included and `to` is excluded. + /// `None` means unbounded on that side. pub validity_window: ValidityWindow, } @@ -181,7 +246,7 @@ impl ProgramOutput { pre_states, post_states, chained_calls: Vec::new(), - validity_window: (None, None), + validity_window: ValidityWindow::new_unbounded(), } } @@ -195,16 +260,16 @@ impl ProgramOutput { self } - #[must_use] - pub const fn valid_from_id(mut self, id: BlockId) -> Self { - self.validity_window.0 = Some(id); - self + pub fn valid_from_id(mut self, id: Option) -> Result { + self.validity_window.from = id; + self.validity_window.check_window()?; + Ok(self) } - #[must_use] - pub const fn valid_until_id(mut self, id: BlockId) -> Self { - self.validity_window.1 = Some(id); - self + pub fn valid_until_id(mut self, id: Option) -> Result { + self.validity_window.to = id; + self.validity_window.check_window()?; + Ok(self) } } diff --git a/nssa/src/privacy_preserving_transaction/message.rs b/nssa/src/privacy_preserving_transaction/message.rs index a79b1ffa..251bd874 100644 --- a/nssa/src/privacy_preserving_transaction/message.rs +++ b/nssa/src/privacy_preserving_transaction/message.rs @@ -165,7 +165,7 @@ pub mod tests { encrypted_private_post_states, new_commitments, new_nullifiers, - validity_window: (None, None), + validity_window: (None, None).try_into().unwrap(), } } diff --git a/nssa/src/privacy_preserving_transaction/transaction.rs b/nssa/src/privacy_preserving_transaction/transaction.rs index 7db6bdb2..1766af23 100644 --- a/nssa/src/privacy_preserving_transaction/transaction.rs +++ b/nssa/src/privacy_preserving_transaction/transaction.rs @@ -93,6 +93,11 @@ impl PrivacyPreservingTransaction { } } + // Verify validity window + if !message.validity_window.is_valid_for_block_id(block_id) { + return Err(NssaError::OutOfValidityWindow); + } + // Build pre_states for proof verification let public_pre_states: Vec<_> = message .public_account_ids @@ -123,18 +128,6 @@ impl PrivacyPreservingTransaction { // 6. Nullifier uniqueness state.check_nullifiers_are_valid(&message.new_nullifiers)?; - // 7. Verify validity window - if let Some(from_id) = message.validity_window.0 - && block_id < from_id - { - return Err(NssaError::OutOfValidityWindow); - } - if let Some(until_id) = message.validity_window.1 - && until_id < block_id - { - return Err(NssaError::OutOfValidityWindow); - } - Ok(message .public_account_ids .iter() diff --git a/nssa/src/public_transaction/transaction.rs b/nssa/src/public_transaction/transaction.rs index 30d2e92f..d0ffc99a 100644 --- a/nssa/src/public_transaction/transaction.rs +++ b/nssa/src/public_transaction/transaction.rs @@ -192,12 +192,12 @@ impl PublicTransaction { ); // Verify validity window - if let Some(from_id) = program_output.validity_window.0 { - ensure!(from_id <= block_id, NssaError::OutOfValidityWindow); - } - if let Some(until_id) = program_output.validity_window.1 { - ensure!(until_id >= block_id, NssaError::OutOfValidityWindow); - } + ensure!( + program_output + .validity_window + .is_valid_for_block_id(block_id), + NssaError::OutOfValidityWindow + ); for post in program_output .post_states diff --git a/nssa/src/state.rs b/nssa/src/state.rs index ec2950fb..e94eea2f 100644 --- a/nssa/src/state.rs +++ b/nssa/src/state.rs @@ -340,7 +340,7 @@ pub mod tests { Commitment, Nullifier, NullifierPublicKey, NullifierSecretKey, SharedSecretKey, account::{Account, AccountId, AccountWithMetadata, Nonce, data::Data}, encryption::{EphemeralPublicKey, Scalar, ViewingPublicKey}, - program::{BlockId, PdaSeed, ProgramId, ValidityWindow}, + program::{BlockId, PdaSeed, ProgramId}, }; use crate::{ @@ -3013,7 +3013,7 @@ pub mod tests { #[test_case::test_case((None, None), 0; "no bounds - always valid")] #[test_case::test_case((None, None), 100; "no bounds - always valid 2")] fn validity_window_works_in_public_transactions( - validity_window: ValidityWindow, + validity_window: (Option, Option), block_id: BlockId, ) { let validity_window_program = Program::validity_window(); @@ -3035,10 +3035,10 @@ pub mod tests { PublicTransaction::new(message, witness_set) }; let result = state.transition_from_public_transaction(&tx, block_id); - let is_inside_validity_window = match (validity_window.0, validity_window.1) { - (Some(s), Some(e)) => s <= block_id && block_id <= e, + let is_inside_validity_window = match validity_window { + (Some(s), Some(e)) => s <= block_id && block_id < e, (Some(s), None) => s <= block_id, - (None, Some(e)) => block_id <= e, + (None, Some(e)) => block_id < e, (None, None) => true, }; if is_inside_validity_window { @@ -3062,7 +3062,7 @@ pub mod tests { #[test_case::test_case((None, None), 0; "no bounds - always valid")] #[test_case::test_case((None, None), 100; "no bounds - always valid 2")] fn validity_window_works_in_privacy_preserving_transactions( - validity_window: ValidityWindow, + validity_window: (Option, Option), block_id: BlockId, ) { let validity_window_program = Program::validity_window(); @@ -3097,10 +3097,10 @@ pub mod tests { PrivacyPreservingTransaction::new(message, witness_set) }; let result = state.transition_from_privacy_preserving_transaction(&tx, block_id); - let is_inside_validity_window = match (validity_window.0, validity_window.1) { - (Some(s), Some(e)) => s <= block_id && block_id <= e, + let is_inside_validity_window = match validity_window { + (Some(s), Some(e)) => s <= block_id && block_id < e, (Some(s), None) => s <= block_id, - (None, Some(e)) => block_id <= e, + (None, Some(e)) => block_id < e, (None, None) => true, }; if is_inside_validity_window { diff --git a/program_methods/guest/src/bin/privacy_preserving_circuit.rs b/program_methods/guest/src/bin/privacy_preserving_circuit.rs index db767e3e..08872564 100644 --- a/program_methods/guest/src/bin/privacy_preserving_circuit.rs +++ b/program_methods/guest/src/bin/privacy_preserving_circuit.rs @@ -28,17 +28,21 @@ impl ExecutionState { pub fn derive_from_outputs(program_id: ProgramId, program_outputs: Vec) -> Self { let valid_from_id = program_outputs .iter() - .filter_map(|output| output.validity_window.0) + .filter_map(|output| output.validity_window.from()) .max(); let valid_until_id = program_outputs .iter() - .filter_map(|output| output.validity_window.1) + .filter_map(|output| output.validity_window.to()) .min(); + let validity_window = (valid_from_id, valid_until_id).try_into().expect( + "There should be non empty intersection in the program output validity windows", + ); + let mut execution_state = Self { pre_states: Vec::new(), post_states: HashMap::new(), - validity_window: (valid_from_id, valid_until_id), + validity_window, }; let Some(first_output) = program_outputs.first() else { diff --git a/test_program_methods/guest/src/bin/validity_window.rs b/test_program_methods/guest/src/bin/validity_window.rs index dbea9849..03f31073 100644 --- a/test_program_methods/guest/src/bin/validity_window.rs +++ b/test_program_methods/guest/src/bin/validity_window.rs @@ -19,18 +19,15 @@ fn main() { let post = pre.account.clone(); - let mut output = ProgramOutput::new( + let output = ProgramOutput::new( instruction_words, vec![pre], vec![AccountPostState::new(post)], - ); - - if let Some(id) = from_id { - output = output.valid_from_id(id); - } - if let Some(id) = until_id { - output = output.valid_until_id(id); - } + ) + .valid_from_id(from_id) + .unwrap() + .valid_until_id(until_id) + .unwrap(); output.write(); }