From efff80de678701a74f02543950e1abb3dc7b9f99 Mon Sep 17 00:00:00 2001 From: holisticode <88287+holisticode@users.noreply.github.com> Date: Tue, 3 Sep 2024 03:25:12 -0500 Subject: [PATCH] Sampling backend implementation (#709) * implemented backend; correct start of monitoring thread needs to be finished * Fix kzgrs backend (#710) * abstract rng * fix rng abstraction * replaced subnets vector with HashSet, fixed bugs, added tests * addressed PR comments * fix clippy warnings * Rename TrackingState -> SamplingState * Short circuit failure on init error --------- Co-authored-by: Daniel Sanchez Co-authored-by: danielSanchezQ <3danimanimal@gmail.com> --- .../core/src/protocols/sampling/behaviour.rs | 7 + .../data-availability/sampling/Cargo.toml | 1 + .../sampling/src/backend/kzgrs.rs | 366 ++++++++++++++++++ .../sampling/src/backend/mod.rs | 20 +- .../data-availability/sampling/src/lib.rs | 50 ++- 5 files changed, 425 insertions(+), 19 deletions(-) diff --git a/nomos-da/network/core/src/protocols/sampling/behaviour.rs b/nomos-da/network/core/src/protocols/sampling/behaviour.rs index 10679d1b..53f8184f 100644 --- a/nomos-da/network/core/src/protocols/sampling/behaviour.rs +++ b/nomos-da/network/core/src/protocols/sampling/behaviour.rs @@ -73,6 +73,13 @@ impl SamplingError { SamplingError::ResponseChannel { peer_id, .. } => peer_id, } } + + pub fn blob_id(&self) -> Option<&BlobId> { + match self { + SamplingError::Deserialize { blob_id, .. } => Some(blob_id), + _ => None, + } + } } impl Clone for SamplingError { diff --git a/nomos-services/data-availability/sampling/Cargo.toml b/nomos-services/data-availability/sampling/Cargo.toml index ae1c38a0..d917e99f 100644 --- a/nomos-services/data-availability/sampling/Cargo.toml +++ b/nomos-services/data-availability/sampling/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" async-trait = "0.1" bytes = "1.2" futures = "0.3" +hex = "0.4.3" kzgrs-backend = { path = "../../../nomos-da/kzgrs-backend" } libp2p-identity = { version = "0.2" } nomos-core = { path = "../../../nomos-core" } diff --git a/nomos-services/data-availability/sampling/src/backend/kzgrs.rs b/nomos-services/data-availability/sampling/src/backend/kzgrs.rs index 139597f9..2d7bb76e 100644 --- a/nomos-services/data-availability/sampling/src/backend/kzgrs.rs +++ b/nomos-services/data-availability/sampling/src/backend/kzgrs.rs @@ -1,2 +1,368 @@ +// std +use std::collections::{BTreeSet, HashMap, HashSet}; +use std::fmt::Debug; +use std::time::{Duration, Instant}; +// crates +use hex; +use rand::distributions::Standard; +use rand::prelude::*; +use tokio::time; +use tokio::time::Interval; +use kzgrs_backend::common::blob::DaBlob; + +// +// internal +use crate::{backend::SamplingState, DaSamplingServiceBackend}; +use nomos_core::da::BlobId; +use nomos_da_network_core::SubnetworkId; + +#[derive(Clone)] +pub struct SamplingContext { + subnets: HashSet, + started: Instant, +} + +#[derive(Debug, Clone)] +pub struct KzgrsDaSamplerSettings { + pub num_samples: u16, + pub old_blobs_check_interval: Duration, + pub blobs_validity_duration: Duration, +} + +pub struct KzgrsDaSampler { + settings: KzgrsDaSamplerSettings, + validated_blobs: BTreeSet, + pending_sampling_blobs: HashMap, + rng: R, +} + +impl KzgrsDaSampler { + fn prune_by_time(&mut self) { + self.pending_sampling_blobs.retain(|_blob_id, context| { + context.started.elapsed() < self.settings.blobs_validity_duration + }); + } +} + +#[async_trait::async_trait] +impl DaSamplingServiceBackend for KzgrsDaSampler { + type Settings = KzgrsDaSamplerSettings; + type BlobId = BlobId; + type Blob = DaBlob; + + fn new(settings: Self::Settings, rng: R) -> Self { + let bt: BTreeSet = BTreeSet::new(); + Self { + settings, + validated_blobs: bt, + pending_sampling_blobs: HashMap::new(), + rng, + } + } + + async fn prune_interval(&self) -> Interval { + time::interval(self.settings.old_blobs_check_interval) + } + + async fn get_validated_blobs(&self) -> BTreeSet { + self.validated_blobs.clone() + } + + async fn mark_in_block(&mut self, blobs_ids: &[Self::BlobId]) { + for id in blobs_ids { + self.pending_sampling_blobs.remove(id); + self.validated_blobs.remove(id); + } + } + + async fn handle_sampling_success(&mut self, blob_id: Self::BlobId, blob: Self::Blob) { + if let Some(ctx) = self.pending_sampling_blobs.get_mut(&blob_id) { + tracing::info!( + "subnet {} for blob id {} has been successfully sampled", + blob.column_idx, + hex::encode(blob_id) + ); + ctx.subnets.insert(blob.column_idx as SubnetworkId); + + // sampling of this blob_id terminated successfully + if ctx.subnets.len() == self.settings.num_samples as usize { + self.validated_blobs.insert(blob_id); + tracing::info!( + "blob_id {} has been successfully sampled", + hex::encode(blob_id) + ); + // cleanup from pending samplings + self.pending_sampling_blobs.remove(&blob_id); + } + } + } + + async fn handle_sampling_error(&mut self, blob_id: Self::BlobId) { + // If it fails a single time we consider it failed. + // We may want to abstract the sampling policies somewhere else at some point if we + // need to get fancier than this + self.pending_sampling_blobs.remove(&blob_id); + self.validated_blobs.remove(&blob_id); + } + + async fn init_sampling(&mut self, blob_id: Self::BlobId) -> SamplingState { + if self.pending_sampling_blobs.contains_key(&blob_id) { + return SamplingState::Tracking; + } + if self.validated_blobs.contains(&blob_id) { + return SamplingState::Terminated; + } + + let subnets: Vec = Standard + .sample_iter(&mut self.rng) + .take(self.settings.num_samples as usize) + .collect(); + let ctx: SamplingContext = SamplingContext { + subnets: HashSet::new(), + started: Instant::now(), + }; + self.pending_sampling_blobs.insert(blob_id, ctx); + SamplingState::Init(subnets) + } + + fn prune(&mut self) { + self.prune_by_time() + } +} + +#[cfg(test)] +mod test { + + use std::collections::HashSet; + use std::time::{Duration, Instant}; + + use rand::prelude::*; + use rand::rngs::StdRng; + + use crate::backend::kzgrs::{ + DaSamplingServiceBackend, KzgrsDaSampler, KzgrsDaSamplerSettings, SamplingContext, + SamplingState, + }; + use kzgrs_backend::common::{blob::DaBlob, Column}; + use nomos_core::da::BlobId; + + fn create_sampler(subnet_num: usize) -> KzgrsDaSampler { + let settings = KzgrsDaSamplerSettings { + num_samples: subnet_num as u16, + old_blobs_check_interval: Duration::from_millis(20), + blobs_validity_duration: Duration::from_millis(10), + }; + let rng = StdRng::from_entropy(); + KzgrsDaSampler::new(settings, rng) + } + + #[tokio::test] + async fn test_sampler() { + // fictitious number of subnets + let subnet_num: usize = 42; + + // create a sampler instance + let sampler = &mut create_sampler(subnet_num); + + // create some blobs and blob_ids + let b1: BlobId = sampler.rng.gen(); + let b2: BlobId = sampler.rng.gen(); + let blob = DaBlob { + column_idx: 42, + column: Column(vec![]), + column_commitment: Default::default(), + aggregated_column_commitment: Default::default(), + aggregated_column_proof: Default::default(), + rows_commitments: vec![], + rows_proofs: vec![], + }; + let blob2 = blob.clone(); + let mut blob3 = blob2.clone(); + + // at start everything should be empty + assert!(sampler.pending_sampling_blobs.is_empty()); + assert!(sampler.validated_blobs.is_empty()); + assert!(sampler.get_validated_blobs().await.is_empty()); + + // start sampling for b1 + let SamplingState::Init(subnets_to_sample) = sampler.init_sampling(b1).await else { + panic!("unexpected return value") + }; + assert!(subnets_to_sample.len() == subnet_num); + assert!(sampler.validated_blobs.is_empty()); + assert!(sampler.pending_sampling_blobs.len() == 1); + + // start sampling for b2 + let SamplingState::Init(subnets_to_sample2) = sampler.init_sampling(b2).await else { + panic!("unexpected return value") + }; + assert!(subnets_to_sample2.len() == subnet_num); + assert!(sampler.validated_blobs.is_empty()); + assert!(sampler.pending_sampling_blobs.len() == 2); + + // mark in block for both + // collections should be reset + sampler.mark_in_block(&[b1, b2]).await; + assert!(sampler.pending_sampling_blobs.is_empty()); + assert!(sampler.validated_blobs.is_empty()); + + // because they're reset, we need to restart sampling for the test + _ = sampler.init_sampling(b1).await; + _ = sampler.init_sampling(b2).await; + + // handle ficticious error for b2 + // b2 should be gone, b1 still around + sampler.handle_sampling_error(b2).await; + assert!(sampler.validated_blobs.is_empty()); + assert!(sampler.pending_sampling_blobs.len() == 1); + assert!(sampler.pending_sampling_blobs.contains_key(&b1)); + + // handle ficticious sampling success for b1 + // should still just have one pending blob, no validated blobs yet, + // and one subnet added to blob + sampler.handle_sampling_success(b1, blob).await; + assert!(sampler.validated_blobs.is_empty()); + assert!(sampler.pending_sampling_blobs.len() == 1); + println!( + "{}", + sampler + .pending_sampling_blobs + .get(&b1) + .unwrap() + .subnets + .len() + ); + assert!( + sampler + .pending_sampling_blobs + .get(&b1) + .unwrap() + .subnets + .len() + == 1 + ); + + // handle_success for always the same subnet + // by adding number of subnets time the same subnet + // (column_idx did not change) + // should not change, still just one sampling blob, + // no validated blobs and one subnet + for _ in 1..subnet_num { + let b = blob2.clone(); + sampler.handle_sampling_success(b1, b).await; + } + assert!(sampler.validated_blobs.is_empty()); + assert!(sampler.pending_sampling_blobs.len() == 1); + assert!( + sampler + .pending_sampling_blobs + .get(&b1) + .unwrap() + .subnets + .len() + == 1 + ); + + // handle_success for up to subnet size minus one subnet + // should still not change anything + // but subnets len is now subnet size minus one + // we already added subnet 42 + for i in 1..(subnet_num - 1) { + let mut b = blob2.clone(); + b.column_idx = i as u16; + sampler.handle_sampling_success(b1, b).await; + } + assert!(sampler.validated_blobs.is_empty()); + assert!(sampler.pending_sampling_blobs.len() == 1); + assert!( + sampler + .pending_sampling_blobs + .get(&b1) + .unwrap() + .subnets + .len() + == subnet_num - 1 + ); + + // now add the last subnet! + // we should have all subnets set, + // and the validated blobs should now have that blob + // pending blobs should now be empty + blob3.column_idx = (subnet_num - 1) as u16; + sampler.handle_sampling_success(b1, blob3).await; + assert!(sampler.pending_sampling_blobs.is_empty()); + assert!(sampler.validated_blobs.len() == 1); + assert!(sampler.validated_blobs.contains(&b1)); + // these checks are redundant but better safe than sorry + assert!(sampler.get_validated_blobs().await.len() == 1); + assert!(sampler.get_validated_blobs().await.contains(&b1)); + + // run mark_in_block for the same blob + // should return empty for everything + sampler.mark_in_block(&[b1]).await; + assert!(sampler.validated_blobs.is_empty()); + assert!(sampler.pending_sampling_blobs.is_empty()); + } + + #[tokio::test] + async fn test_pruning() { + let mut sampler = create_sampler(42); + + // create some sampling contexes + // first set will go through as in time + let ctx1 = SamplingContext { + subnets: HashSet::new(), + started: Instant::now(), + }; + let ctx2 = ctx1.clone(); + let ctx3 = ctx1.clone(); + + // second set: will fail for expired + let ctx11 = SamplingContext { + subnets: HashSet::new(), + started: Instant::now() - Duration::from_secs(1), + }; + let ctx12 = ctx11.clone(); + let ctx13 = ctx11.clone(); + + // create a couple blob ids + let b1: BlobId = sampler.rng.gen(); + let b2: BlobId = sampler.rng.gen(); + let b3: BlobId = sampler.rng.gen(); + + // insert first blob + // pruning should have no effect + assert!(sampler.pending_sampling_blobs.is_empty()); + sampler.pending_sampling_blobs.insert(b1, ctx1); + sampler.prune(); + assert!(sampler.pending_sampling_blobs.len() == 1); + + // insert second blob + // pruning should have no effect + sampler.pending_sampling_blobs.insert(b2, ctx2); + sampler.prune(); + assert!(sampler.pending_sampling_blobs.len() == 2); + + // insert third blob + // pruning should have no effect + sampler.pending_sampling_blobs.insert(b3, ctx3); + sampler.prune(); + assert!(sampler.pending_sampling_blobs.len() == 3); + + // insert fake expired blobs + // pruning these should now decrease pending blobx every time + sampler.pending_sampling_blobs.insert(b1, ctx11); + sampler.prune(); + assert!(sampler.pending_sampling_blobs.len() == 2); + + sampler.pending_sampling_blobs.insert(b2, ctx12); + sampler.prune(); + assert!(sampler.pending_sampling_blobs.len() == 1); + + sampler.pending_sampling_blobs.insert(b3, ctx13); + sampler.prune(); + assert!(sampler.pending_sampling_blobs.is_empty()); + } +} diff --git a/nomos-services/data-availability/sampling/src/backend/mod.rs b/nomos-services/data-availability/sampling/src/backend/mod.rs index 4d508843..0cc6e758 100644 --- a/nomos-services/data-availability/sampling/src/backend/mod.rs +++ b/nomos-services/data-availability/sampling/src/backend/mod.rs @@ -1,21 +1,33 @@ +pub mod kzgrs; + // std use std::collections::BTreeSet; // crates +use rand::Rng; +use tokio::time::Interval; // // internal use nomos_da_network_core::SubnetworkId; +pub enum SamplingState { + Init(Vec), + Tracking, + Terminated, +} + #[async_trait::async_trait] -pub trait DaSamplingServiceBackend { +pub trait DaSamplingServiceBackend { type Settings; type BlobId; type Blob; - fn new(settings: Self::Settings) -> Self; + fn new(settings: Self::Settings, rng: R) -> Self; async fn get_validated_blobs(&self) -> BTreeSet; - async fn mark_in_block(&mut self, blobs_id: &[Self::BlobId]); + async fn mark_in_block(&mut self, blobs_ids: &[Self::BlobId]); async fn handle_sampling_success(&mut self, blob_id: Self::BlobId, blob: Self::Blob); async fn handle_sampling_error(&mut self, blob_id: Self::BlobId); - async fn init_sampling(&mut self, blob_id: Self::BlobId) -> Vec; + async fn init_sampling(&mut self, blob_id: Self::BlobId) -> SamplingState; + async fn prune_interval(&self) -> Interval; + fn prune(&mut self); } diff --git a/nomos-services/data-availability/sampling/src/lib.rs b/nomos-services/data-availability/sampling/src/lib.rs index 67bf5b8d..9eea10d9 100644 --- a/nomos-services/data-availability/sampling/src/lib.rs +++ b/nomos-services/data-availability/sampling/src/lib.rs @@ -6,10 +6,11 @@ use std::collections::BTreeSet; use std::fmt::Debug; // crates +use rand::prelude::*; use tokio_stream::StreamExt; use tracing::{error, span, Instrument, Level}; // internal -use backend::DaSamplingServiceBackend; +use backend::{DaSamplingServiceBackend, SamplingState}; use kzgrs_backend::common::blob::DaBlob; use network::NetworkAdapter; use nomos_core::da::BlobId; @@ -46,9 +47,10 @@ pub struct DaSamplingServiceSettings { impl RelayMessage for DaSamplingServiceMsg {} -pub struct DaSamplingService +pub struct DaSamplingService where - Backend: DaSamplingServiceBackend + Send, + R: SeedableRng + RngCore, + Backend: DaSamplingServiceBackend + Send, Backend::Settings: Clone, Backend::Blob: Debug + 'static, Backend::BlobId: Debug + 'static, @@ -60,9 +62,10 @@ where sampler: Backend, } -impl DaSamplingService +impl DaSamplingService where - Backend: DaSamplingServiceBackend + Send + 'static, + R: SeedableRng + RngCore, + Backend: DaSamplingServiceBackend + Send + 'static, Backend::Settings: Clone, N: NetworkAdapter + Send + 'static, N::Settings: Clone, @@ -89,12 +92,16 @@ where ) { match msg { DaSamplingServiceMsg::TriggerSampling { blob_id } => { - let sampling_subnets = sampler.init_sampling(blob_id).await; - if let Err(e) = network_adapter - .start_sampling(blob_id, &sampling_subnets) - .await + if let SamplingState::Init(sampling_subnets) = sampler.init_sampling(blob_id).await { - error!("Error sampling for BlobId: {blob_id:?}: {e}"); + if let Err(e) = network_adapter + .start_sampling(blob_id, &sampling_subnets) + .await + { + // we can short circuit the failure from beginning + sampler.handle_sampling_error(blob_id).await; + error!("Error sampling for BlobId: {blob_id:?}: {e}"); + } } } DaSamplingServiceMsg::GetValidatedBlobs { reply_channel } => { @@ -115,15 +122,20 @@ where sampler.handle_sampling_success(blob_id, *blob).await; } SamplingEvent::SamplingError { error } => { + if let Some(blob_id) = error.blob_id() { + sampler.handle_sampling_error(*blob_id).await; + return; + } error!("Error while sampling: {error}"); } } } } -impl ServiceData for DaSamplingService +impl ServiceData for DaSamplingService where - Backend: DaSamplingServiceBackend + Send, + R: SeedableRng + RngCore, + Backend: DaSamplingServiceBackend + Send, Backend::Settings: Clone, Backend::Blob: Debug + 'static, Backend::BlobId: Debug + 'static, @@ -138,9 +150,10 @@ where } #[async_trait::async_trait] -impl ServiceCore for DaSamplingService +impl ServiceCore for DaSamplingService where - Backend: DaSamplingServiceBackend + Send + Sync + 'static, + R: SeedableRng + RngCore, + Backend: DaSamplingServiceBackend + Send + Sync + 'static, Backend::Settings: Clone + Send + Sync + 'static, N: NetworkAdapter + Send + Sync + 'static, N::Settings: Clone + Send + Sync + 'static, @@ -151,11 +164,12 @@ where } = service_state.settings_reader.get_updated_settings(); let network_relay = service_state.overwatch_handle.relay(); + let rng = R::from_entropy(); Ok(Self { network_relay, service_state, - sampler: Backend::new(sampling_settings), + sampler: Backend::new(sampling_settings, rng), }) } @@ -175,6 +189,7 @@ where let mut network_adapter = N::new(network_relay).await; let mut sampling_message_stream = network_adapter.listen_to_sampling_messages().await?; + let mut next_prune_tick = sampler.prune_interval().await; let mut lifecycle_stream = service_state.lifecycle_handle.message_stream(); async { @@ -191,6 +206,11 @@ where break; } } + // cleanup not on time samples + _ = next_prune_tick.tick() => { + sampler.prune(); + } + } } }