use std::ops::Deref; use borsh::{BorshDeserialize, BorshSerialize}; use serde::{Deserialize, Serialize}; pub const DATA_MAX_LENGTH_IN_BYTES: usize = 100 * 1024; // 100 KiB #[derive(Default, Clone, PartialEq, Eq, Serialize, BorshSerialize)] #[cfg_attr(any(feature = "host", test), derive(Debug))] pub struct Data(Vec); impl Data { pub fn into_inner(self) -> Vec { self.0 } #[cfg(feature = "host")] pub fn from_cursor( cursor: &mut std::io::Cursor<&[u8]>, ) -> Result { use std::io::Read as _; let mut u32_bytes = [0u8; 4]; cursor.read_exact(&mut u32_bytes)?; let data_length = u32::from_le_bytes(u32_bytes); if data_length as usize > DATA_MAX_LENGTH_IN_BYTES { return Err( std::io::Error::new(std::io::ErrorKind::InvalidData, DataTooBigError).into(), ); } let mut data = vec![0; data_length as usize]; cursor.read_exact(&mut data)?; Ok(Self(data)) } } #[derive(Debug, thiserror::Error, Clone, Copy, PartialEq, Eq)] #[error("data length exceeds maximum allowed length of {DATA_MAX_LENGTH_IN_BYTES} bytes")] pub struct DataTooBigError; impl From for Vec { fn from(data: Data) -> Self { data.0 } } impl TryFrom> for Data { type Error = DataTooBigError; fn try_from(value: Vec) -> Result { if value.len() > DATA_MAX_LENGTH_IN_BYTES { Err(DataTooBigError) } else { Ok(Self(value)) } } } impl Deref for Data { type Target = [u8]; fn deref(&self) -> &Self::Target { &self.0 } } impl AsRef<[u8]> for Data { fn as_ref(&self) -> &[u8] { &self.0 } } impl<'de> Deserialize<'de> for Data { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { /// Data deserialization visitor. /// /// Compared to a simple deserialization into a `Vec`, this visitor enforces /// early length check defined by [`DATA_MAX_LENGTH_IN_BYTES`]. struct DataVisitor; impl<'de> serde::de::Visitor<'de> for DataVisitor { type Value = Data; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { write!( formatter, "a byte array with length not exceeding {} bytes", DATA_MAX_LENGTH_IN_BYTES ) } fn visit_seq(self, mut seq: A) -> Result where A: serde::de::SeqAccess<'de>, { let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(0).min(DATA_MAX_LENGTH_IN_BYTES)); while let Some(value) = seq.next_element()? { if vec.len() >= DATA_MAX_LENGTH_IN_BYTES { return Err(serde::de::Error::custom(DataTooBigError)); } vec.push(value); } Ok(Data(vec)) } } deserializer.deserialize_seq(DataVisitor) } } impl BorshDeserialize for Data { fn deserialize_reader(reader: &mut R) -> std::io::Result { // Implementation adapted from `impl BorshDeserialize for Vec` let len = u32::deserialize_reader(reader)?; match len { 0 => Ok(Self::default()), len if len as usize > DATA_MAX_LENGTH_IN_BYTES => Err(std::io::Error::new( std::io::ErrorKind::InvalidData, DataTooBigError, )), len => { let vec_bytes = u8::vec_from_reader(len, reader)? .expect("can't be None in current borsh crate implementation"); Ok(Self(vec_bytes)) } } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_data_max_length_allowed() { let max_vec = vec![0u8; DATA_MAX_LENGTH_IN_BYTES]; let result = Data::try_from(max_vec); assert!(result.is_ok()); } #[test] fn test_data_too_big_error() { let big_vec = vec![0u8; DATA_MAX_LENGTH_IN_BYTES + 1]; let result = Data::try_from(big_vec); assert!(matches!(result, Err(DataTooBigError))); } #[test] fn test_borsh_deserialize_exceeding_limit_error() { let too_big_data = vec![0u8; DATA_MAX_LENGTH_IN_BYTES + 1]; let mut serialized = Vec::new(); <_ as BorshSerialize>::serialize(&too_big_data, &mut serialized).unwrap(); let result = ::deserialize(&mut serialized.as_ref()); assert!(result.is_err()); } #[test] fn test_json_deserialize_exceeding_limit_error() { let data = vec![0u8; DATA_MAX_LENGTH_IN_BYTES + 1]; let json = serde_json::to_string(&data).unwrap(); let result: Result = serde_json::from_str(&json); assert!(result.is_err()); } }