Sampling backend implementation (#709)

* implemented backend; correct start of monitoring thread needs to be finished

* Fix kzgrs backend (#710)

* abstract rng

* fix rng abstraction

* replaced subnets vector with HashSet, fixed bugs, added tests

* addressed PR comments

* fix clippy warnings

* Rename TrackingState -> SamplingState

* Short circuit failure on init error

---------

Co-authored-by: Daniel Sanchez <sanchez.quiros.daniel@gmail.com>
Co-authored-by: danielSanchezQ <3danimanimal@gmail.com>
This commit is contained in:
holisticode 2024-09-03 03:25:12 -05:00 committed by GitHub
parent 7dc2111341
commit efff80de67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 425 additions and 19 deletions

View File

@ -73,6 +73,13 @@ impl SamplingError {
SamplingError::ResponseChannel { peer_id, .. } => peer_id, SamplingError::ResponseChannel { peer_id, .. } => peer_id,
} }
} }
pub fn blob_id(&self) -> Option<&BlobId> {
match self {
SamplingError::Deserialize { blob_id, .. } => Some(blob_id),
_ => None,
}
}
} }
impl Clone for SamplingError { impl Clone for SamplingError {

View File

@ -7,6 +7,7 @@ edition = "2021"
async-trait = "0.1" async-trait = "0.1"
bytes = "1.2" bytes = "1.2"
futures = "0.3" futures = "0.3"
hex = "0.4.3"
kzgrs-backend = { path = "../../../nomos-da/kzgrs-backend" } kzgrs-backend = { path = "../../../nomos-da/kzgrs-backend" }
libp2p-identity = { version = "0.2" } libp2p-identity = { version = "0.2" }
nomos-core = { path = "../../../nomos-core" } nomos-core = { path = "../../../nomos-core" }

View File

@ -1,2 +1,368 @@
// std
use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt::Debug;
use std::time::{Duration, Instant};
// crates
use hex;
use rand::distributions::Standard;
use rand::prelude::*;
use tokio::time;
use tokio::time::Interval;
use kzgrs_backend::common::blob::DaBlob;
//
// internal
use crate::{backend::SamplingState, DaSamplingServiceBackend};
use nomos_core::da::BlobId;
use nomos_da_network_core::SubnetworkId;
#[derive(Clone)]
pub struct SamplingContext {
subnets: HashSet<SubnetworkId>,
started: Instant,
}
#[derive(Debug, Clone)]
pub struct KzgrsDaSamplerSettings {
pub num_samples: u16,
pub old_blobs_check_interval: Duration,
pub blobs_validity_duration: Duration,
}
pub struct KzgrsDaSampler<R: Rng> {
settings: KzgrsDaSamplerSettings,
validated_blobs: BTreeSet<BlobId>,
pending_sampling_blobs: HashMap<BlobId, SamplingContext>,
rng: R,
}
impl<R: Rng> KzgrsDaSampler<R> {
fn prune_by_time(&mut self) {
self.pending_sampling_blobs.retain(|_blob_id, context| {
context.started.elapsed() < self.settings.blobs_validity_duration
});
}
}
#[async_trait::async_trait]
impl<R: Rng + Sync + Send> DaSamplingServiceBackend<R> for KzgrsDaSampler<R> {
type Settings = KzgrsDaSamplerSettings;
type BlobId = BlobId;
type Blob = DaBlob;
fn new(settings: Self::Settings, rng: R) -> Self {
let bt: BTreeSet<BlobId> = BTreeSet::new();
Self {
settings,
validated_blobs: bt,
pending_sampling_blobs: HashMap::new(),
rng,
}
}
async fn prune_interval(&self) -> Interval {
time::interval(self.settings.old_blobs_check_interval)
}
async fn get_validated_blobs(&self) -> BTreeSet<Self::BlobId> {
self.validated_blobs.clone()
}
async fn mark_in_block(&mut self, blobs_ids: &[Self::BlobId]) {
for id in blobs_ids {
self.pending_sampling_blobs.remove(id);
self.validated_blobs.remove(id);
}
}
async fn handle_sampling_success(&mut self, blob_id: Self::BlobId, blob: Self::Blob) {
if let Some(ctx) = self.pending_sampling_blobs.get_mut(&blob_id) {
tracing::info!(
"subnet {} for blob id {} has been successfully sampled",
blob.column_idx,
hex::encode(blob_id)
);
ctx.subnets.insert(blob.column_idx as SubnetworkId);
// sampling of this blob_id terminated successfully
if ctx.subnets.len() == self.settings.num_samples as usize {
self.validated_blobs.insert(blob_id);
tracing::info!(
"blob_id {} has been successfully sampled",
hex::encode(blob_id)
);
// cleanup from pending samplings
self.pending_sampling_blobs.remove(&blob_id);
}
}
}
async fn handle_sampling_error(&mut self, blob_id: Self::BlobId) {
// If it fails a single time we consider it failed.
// We may want to abstract the sampling policies somewhere else at some point if we
// need to get fancier than this
self.pending_sampling_blobs.remove(&blob_id);
self.validated_blobs.remove(&blob_id);
}
async fn init_sampling(&mut self, blob_id: Self::BlobId) -> SamplingState {
if self.pending_sampling_blobs.contains_key(&blob_id) {
return SamplingState::Tracking;
}
if self.validated_blobs.contains(&blob_id) {
return SamplingState::Terminated;
}
let subnets: Vec<SubnetworkId> = Standard
.sample_iter(&mut self.rng)
.take(self.settings.num_samples as usize)
.collect();
let ctx: SamplingContext = SamplingContext {
subnets: HashSet::new(),
started: Instant::now(),
};
self.pending_sampling_blobs.insert(blob_id, ctx);
SamplingState::Init(subnets)
}
fn prune(&mut self) {
self.prune_by_time()
}
}
#[cfg(test)]
mod test {
use std::collections::HashSet;
use std::time::{Duration, Instant};
use rand::prelude::*;
use rand::rngs::StdRng;
use crate::backend::kzgrs::{
DaSamplingServiceBackend, KzgrsDaSampler, KzgrsDaSamplerSettings, SamplingContext,
SamplingState,
};
use kzgrs_backend::common::{blob::DaBlob, Column};
use nomos_core::da::BlobId;
fn create_sampler(subnet_num: usize) -> KzgrsDaSampler<StdRng> {
let settings = KzgrsDaSamplerSettings {
num_samples: subnet_num as u16,
old_blobs_check_interval: Duration::from_millis(20),
blobs_validity_duration: Duration::from_millis(10),
};
let rng = StdRng::from_entropy();
KzgrsDaSampler::new(settings, rng)
}
#[tokio::test]
async fn test_sampler() {
// fictitious number of subnets
let subnet_num: usize = 42;
// create a sampler instance
let sampler = &mut create_sampler(subnet_num);
// create some blobs and blob_ids
let b1: BlobId = sampler.rng.gen();
let b2: BlobId = sampler.rng.gen();
let blob = DaBlob {
column_idx: 42,
column: Column(vec![]),
column_commitment: Default::default(),
aggregated_column_commitment: Default::default(),
aggregated_column_proof: Default::default(),
rows_commitments: vec![],
rows_proofs: vec![],
};
let blob2 = blob.clone();
let mut blob3 = blob2.clone();
// at start everything should be empty
assert!(sampler.pending_sampling_blobs.is_empty());
assert!(sampler.validated_blobs.is_empty());
assert!(sampler.get_validated_blobs().await.is_empty());
// start sampling for b1
let SamplingState::Init(subnets_to_sample) = sampler.init_sampling(b1).await else {
panic!("unexpected return value")
};
assert!(subnets_to_sample.len() == subnet_num);
assert!(sampler.validated_blobs.is_empty());
assert!(sampler.pending_sampling_blobs.len() == 1);
// start sampling for b2
let SamplingState::Init(subnets_to_sample2) = sampler.init_sampling(b2).await else {
panic!("unexpected return value")
};
assert!(subnets_to_sample2.len() == subnet_num);
assert!(sampler.validated_blobs.is_empty());
assert!(sampler.pending_sampling_blobs.len() == 2);
// mark in block for both
// collections should be reset
sampler.mark_in_block(&[b1, b2]).await;
assert!(sampler.pending_sampling_blobs.is_empty());
assert!(sampler.validated_blobs.is_empty());
// because they're reset, we need to restart sampling for the test
_ = sampler.init_sampling(b1).await;
_ = sampler.init_sampling(b2).await;
// handle ficticious error for b2
// b2 should be gone, b1 still around
sampler.handle_sampling_error(b2).await;
assert!(sampler.validated_blobs.is_empty());
assert!(sampler.pending_sampling_blobs.len() == 1);
assert!(sampler.pending_sampling_blobs.contains_key(&b1));
// handle ficticious sampling success for b1
// should still just have one pending blob, no validated blobs yet,
// and one subnet added to blob
sampler.handle_sampling_success(b1, blob).await;
assert!(sampler.validated_blobs.is_empty());
assert!(sampler.pending_sampling_blobs.len() == 1);
println!(
"{}",
sampler
.pending_sampling_blobs
.get(&b1)
.unwrap()
.subnets
.len()
);
assert!(
sampler
.pending_sampling_blobs
.get(&b1)
.unwrap()
.subnets
.len()
== 1
);
// handle_success for always the same subnet
// by adding number of subnets time the same subnet
// (column_idx did not change)
// should not change, still just one sampling blob,
// no validated blobs and one subnet
for _ in 1..subnet_num {
let b = blob2.clone();
sampler.handle_sampling_success(b1, b).await;
}
assert!(sampler.validated_blobs.is_empty());
assert!(sampler.pending_sampling_blobs.len() == 1);
assert!(
sampler
.pending_sampling_blobs
.get(&b1)
.unwrap()
.subnets
.len()
== 1
);
// handle_success for up to subnet size minus one subnet
// should still not change anything
// but subnets len is now subnet size minus one
// we already added subnet 42
for i in 1..(subnet_num - 1) {
let mut b = blob2.clone();
b.column_idx = i as u16;
sampler.handle_sampling_success(b1, b).await;
}
assert!(sampler.validated_blobs.is_empty());
assert!(sampler.pending_sampling_blobs.len() == 1);
assert!(
sampler
.pending_sampling_blobs
.get(&b1)
.unwrap()
.subnets
.len()
== subnet_num - 1
);
// now add the last subnet!
// we should have all subnets set,
// and the validated blobs should now have that blob
// pending blobs should now be empty
blob3.column_idx = (subnet_num - 1) as u16;
sampler.handle_sampling_success(b1, blob3).await;
assert!(sampler.pending_sampling_blobs.is_empty());
assert!(sampler.validated_blobs.len() == 1);
assert!(sampler.validated_blobs.contains(&b1));
// these checks are redundant but better safe than sorry
assert!(sampler.get_validated_blobs().await.len() == 1);
assert!(sampler.get_validated_blobs().await.contains(&b1));
// run mark_in_block for the same blob
// should return empty for everything
sampler.mark_in_block(&[b1]).await;
assert!(sampler.validated_blobs.is_empty());
assert!(sampler.pending_sampling_blobs.is_empty());
}
#[tokio::test]
async fn test_pruning() {
let mut sampler = create_sampler(42);
// create some sampling contexes
// first set will go through as in time
let ctx1 = SamplingContext {
subnets: HashSet::new(),
started: Instant::now(),
};
let ctx2 = ctx1.clone();
let ctx3 = ctx1.clone();
// second set: will fail for expired
let ctx11 = SamplingContext {
subnets: HashSet::new(),
started: Instant::now() - Duration::from_secs(1),
};
let ctx12 = ctx11.clone();
let ctx13 = ctx11.clone();
// create a couple blob ids
let b1: BlobId = sampler.rng.gen();
let b2: BlobId = sampler.rng.gen();
let b3: BlobId = sampler.rng.gen();
// insert first blob
// pruning should have no effect
assert!(sampler.pending_sampling_blobs.is_empty());
sampler.pending_sampling_blobs.insert(b1, ctx1);
sampler.prune();
assert!(sampler.pending_sampling_blobs.len() == 1);
// insert second blob
// pruning should have no effect
sampler.pending_sampling_blobs.insert(b2, ctx2);
sampler.prune();
assert!(sampler.pending_sampling_blobs.len() == 2);
// insert third blob
// pruning should have no effect
sampler.pending_sampling_blobs.insert(b3, ctx3);
sampler.prune();
assert!(sampler.pending_sampling_blobs.len() == 3);
// insert fake expired blobs
// pruning these should now decrease pending blobx every time
sampler.pending_sampling_blobs.insert(b1, ctx11);
sampler.prune();
assert!(sampler.pending_sampling_blobs.len() == 2);
sampler.pending_sampling_blobs.insert(b2, ctx12);
sampler.prune();
assert!(sampler.pending_sampling_blobs.len() == 1);
sampler.pending_sampling_blobs.insert(b3, ctx13);
sampler.prune();
assert!(sampler.pending_sampling_blobs.is_empty());
}
}

View File

@ -1,21 +1,33 @@
pub mod kzgrs;
// std // std
use std::collections::BTreeSet; use std::collections::BTreeSet;
// crates // crates
use rand::Rng;
use tokio::time::Interval;
// //
// internal // internal
use nomos_da_network_core::SubnetworkId; use nomos_da_network_core::SubnetworkId;
pub enum SamplingState {
Init(Vec<SubnetworkId>),
Tracking,
Terminated,
}
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait DaSamplingServiceBackend { pub trait DaSamplingServiceBackend<R: Rng> {
type Settings; type Settings;
type BlobId; type BlobId;
type Blob; type Blob;
fn new(settings: Self::Settings) -> Self; fn new(settings: Self::Settings, rng: R) -> Self;
async fn get_validated_blobs(&self) -> BTreeSet<Self::BlobId>; async fn get_validated_blobs(&self) -> BTreeSet<Self::BlobId>;
async fn mark_in_block(&mut self, blobs_id: &[Self::BlobId]); async fn mark_in_block(&mut self, blobs_ids: &[Self::BlobId]);
async fn handle_sampling_success(&mut self, blob_id: Self::BlobId, blob: Self::Blob); async fn handle_sampling_success(&mut self, blob_id: Self::BlobId, blob: Self::Blob);
async fn handle_sampling_error(&mut self, blob_id: Self::BlobId); async fn handle_sampling_error(&mut self, blob_id: Self::BlobId);
async fn init_sampling(&mut self, blob_id: Self::BlobId) -> Vec<SubnetworkId>; async fn init_sampling(&mut self, blob_id: Self::BlobId) -> SamplingState;
async fn prune_interval(&self) -> Interval;
fn prune(&mut self);
} }

View File

@ -6,10 +6,11 @@ use std::collections::BTreeSet;
use std::fmt::Debug; use std::fmt::Debug;
// crates // crates
use rand::prelude::*;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::{error, span, Instrument, Level}; use tracing::{error, span, Instrument, Level};
// internal // internal
use backend::DaSamplingServiceBackend; use backend::{DaSamplingServiceBackend, SamplingState};
use kzgrs_backend::common::blob::DaBlob; use kzgrs_backend::common::blob::DaBlob;
use network::NetworkAdapter; use network::NetworkAdapter;
use nomos_core::da::BlobId; use nomos_core::da::BlobId;
@ -46,9 +47,10 @@ pub struct DaSamplingServiceSettings<BackendSettings, NetworkSettings> {
impl<B: 'static> RelayMessage for DaSamplingServiceMsg<B> {} impl<B: 'static> RelayMessage for DaSamplingServiceMsg<B> {}
pub struct DaSamplingService<Backend, N, S> pub struct DaSamplingService<Backend, N, S, R>
where where
Backend: DaSamplingServiceBackend + Send, R: SeedableRng + RngCore,
Backend: DaSamplingServiceBackend<R> + Send,
Backend::Settings: Clone, Backend::Settings: Clone,
Backend::Blob: Debug + 'static, Backend::Blob: Debug + 'static,
Backend::BlobId: Debug + 'static, Backend::BlobId: Debug + 'static,
@ -60,9 +62,10 @@ where
sampler: Backend, sampler: Backend,
} }
impl<Backend, N, S> DaSamplingService<Backend, N, S> impl<Backend, N, S, R> DaSamplingService<Backend, N, S, R>
where where
Backend: DaSamplingServiceBackend<BlobId = BlobId, Blob = DaBlob> + Send + 'static, R: SeedableRng + RngCore,
Backend: DaSamplingServiceBackend<R, BlobId = BlobId, Blob = DaBlob> + Send + 'static,
Backend::Settings: Clone, Backend::Settings: Clone,
N: NetworkAdapter + Send + 'static, N: NetworkAdapter + Send + 'static,
N::Settings: Clone, N::Settings: Clone,
@ -89,14 +92,18 @@ where
) { ) {
match msg { match msg {
DaSamplingServiceMsg::TriggerSampling { blob_id } => { DaSamplingServiceMsg::TriggerSampling { blob_id } => {
let sampling_subnets = sampler.init_sampling(blob_id).await; if let SamplingState::Init(sampling_subnets) = sampler.init_sampling(blob_id).await
{
if let Err(e) = network_adapter if let Err(e) = network_adapter
.start_sampling(blob_id, &sampling_subnets) .start_sampling(blob_id, &sampling_subnets)
.await .await
{ {
// we can short circuit the failure from beginning
sampler.handle_sampling_error(blob_id).await;
error!("Error sampling for BlobId: {blob_id:?}: {e}"); error!("Error sampling for BlobId: {blob_id:?}: {e}");
} }
} }
}
DaSamplingServiceMsg::GetValidatedBlobs { reply_channel } => { DaSamplingServiceMsg::GetValidatedBlobs { reply_channel } => {
let validated_blobs = sampler.get_validated_blobs().await; let validated_blobs = sampler.get_validated_blobs().await;
if let Err(_e) = reply_channel.send(validated_blobs) { if let Err(_e) = reply_channel.send(validated_blobs) {
@ -115,15 +122,20 @@ where
sampler.handle_sampling_success(blob_id, *blob).await; sampler.handle_sampling_success(blob_id, *blob).await;
} }
SamplingEvent::SamplingError { error } => { SamplingEvent::SamplingError { error } => {
if let Some(blob_id) = error.blob_id() {
sampler.handle_sampling_error(*blob_id).await;
return;
}
error!("Error while sampling: {error}"); error!("Error while sampling: {error}");
} }
} }
} }
} }
impl<Backend, N, S> ServiceData for DaSamplingService<Backend, N, S> impl<Backend, N, S, R> ServiceData for DaSamplingService<Backend, N, S, R>
where where
Backend: DaSamplingServiceBackend + Send, R: SeedableRng + RngCore,
Backend: DaSamplingServiceBackend<R> + Send,
Backend::Settings: Clone, Backend::Settings: Clone,
Backend::Blob: Debug + 'static, Backend::Blob: Debug + 'static,
Backend::BlobId: Debug + 'static, Backend::BlobId: Debug + 'static,
@ -138,9 +150,10 @@ where
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl<Backend, N, S> ServiceCore for DaSamplingService<Backend, N, S> impl<Backend, N, S, R> ServiceCore for DaSamplingService<Backend, N, S, R>
where where
Backend: DaSamplingServiceBackend<BlobId = BlobId, Blob = DaBlob> + Send + Sync + 'static, R: SeedableRng + RngCore,
Backend: DaSamplingServiceBackend<R, BlobId = BlobId, Blob = DaBlob> + Send + Sync + 'static,
Backend::Settings: Clone + Send + Sync + 'static, Backend::Settings: Clone + Send + Sync + 'static,
N: NetworkAdapter + Send + Sync + 'static, N: NetworkAdapter + Send + Sync + 'static,
N::Settings: Clone + Send + Sync + 'static, N::Settings: Clone + Send + Sync + 'static,
@ -151,11 +164,12 @@ where
} = service_state.settings_reader.get_updated_settings(); } = service_state.settings_reader.get_updated_settings();
let network_relay = service_state.overwatch_handle.relay(); let network_relay = service_state.overwatch_handle.relay();
let rng = R::from_entropy();
Ok(Self { Ok(Self {
network_relay, network_relay,
service_state, service_state,
sampler: Backend::new(sampling_settings), sampler: Backend::new(sampling_settings, rng),
}) })
} }
@ -175,6 +189,7 @@ where
let mut network_adapter = N::new(network_relay).await; let mut network_adapter = N::new(network_relay).await;
let mut sampling_message_stream = network_adapter.listen_to_sampling_messages().await?; let mut sampling_message_stream = network_adapter.listen_to_sampling_messages().await?;
let mut next_prune_tick = sampler.prune_interval().await;
let mut lifecycle_stream = service_state.lifecycle_handle.message_stream(); let mut lifecycle_stream = service_state.lifecycle_handle.message_stream();
async { async {
@ -191,6 +206,11 @@ where
break; break;
} }
} }
// cleanup not on time samples
_ = next_prune_tick.tick() => {
sampler.prune();
}
} }
} }
} }