Fix kzgrs backend (#710)

Daniel Sanchez authored on 2024-08-29 16:16:50 +02:00; committed by GitHub
parent 9b13facb7a
commit 47d5e2cdc8
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 37 additions and 51 deletions

View File

@@ -1,12 +1,9 @@
-use std::borrow::BorrowMut;
 // std
 use std::collections::{BTreeSet, HashMap};
 use std::fmt::Debug;
-use std::thread;
-use std::time::Duration;
+use std::time::{Duration, Instant};
 // crates
-use chrono::{naive::NaiveDateTime, Utc};
 use rand::distributions::Standard;
 use rand::prelude::*;
 use rand_chacha::ChaCha20Rng;
@@ -20,9 +17,8 @@ use nomos_core::da::BlobId;
 use nomos_da_network_core::SubnetworkId;
 
 pub struct SamplingContext {
-    blob_id: BlobId,
     subnets: Vec<SubnetworkId>,
-    started: NaiveDateTime,
+    started: Instant,
 }
 
 #[derive(Debug, Clone)]
@@ -35,8 +31,6 @@ pub struct KzgrsDaSamplerSettings {
 pub struct KzgrsDaSampler {
     settings: KzgrsDaSamplerSettings,
     validated_blobs: BTreeSet<BlobId>,
-    // TODO: This needs to be properly synchronized, if this is going to be accessed
-    // by independent threads (monitoring thread)
     pending_sampling_blobs: HashMap<BlobId, SamplingContext>,
     // TODO: is there a better place for this? Do we need to have this even globally?
     // Do we already have some source of randomness already?
@@ -44,25 +38,15 @@ pub struct KzgrsDaSampler {
 }
 
 impl KzgrsDaSampler {
-    // TODO: this might not be the right signature, as the lifetime of self needs to be evaluated
-    async fn start_pending_blob_monitor(&'static mut self) {
-        //let mut sself = self;
-        let monitor = thread::spawn(move || {
-            loop {
-                thread::sleep(self.settings.old_blobs_check_duration);
-                // everything older than cut_timestamp should be removed;
-                let cut_timestamp = Utc::now().naive_utc() - self.settings.blobs_validity_duration;
-                // retain all elements which come after the cut_timestamp
-                self.pending_sampling_blobs
-                    .retain(|_, ctx| ctx.started.gt(&cut_timestamp));
-            }
+    fn prune_by_time(&mut self) {
+        self.pending_sampling_blobs.retain(|_blob_id, context| {
+            context.started.elapsed() < self.settings.old_blobs_check_duration
         });
-        monitor.join().unwrap();
     }
 }
 
 #[async_trait::async_trait]
-impl<'a> DaSamplingServiceBackend for KzgrsDaSampler {
+impl DaSamplingServiceBackend for KzgrsDaSampler {
     type Settings = KzgrsDaSamplerSettings;
     type BlobId = BlobId;
     type Blob = DaBlob;
@@ -70,12 +54,11 @@ impl<'a> DaSamplingServiceBackend for KzgrsDaSampler {
     fn new(settings: Self::Settings) -> Self {
         let bt: BTreeSet<BlobId> = BTreeSet::new();
         Self {
-            settings: settings,
+            settings,
            validated_blobs: bt,
            pending_sampling_blobs: HashMap::new(),
            rng: ChaCha20Rng::from_entropy(),
        }
-        // TODO: how to start the actual monitoring thread with the correct ownership/lifetime?
     }
 
     async fn get_validated_blobs(&self) -> BTreeSet<Self::BlobId> {
@@ -84,48 +67,48 @@ impl<'a> DaSamplingServiceBackend for KzgrsDaSampler {
 
     async fn mark_in_block(&mut self, blobs_ids: &[Self::BlobId]) {
         for id in blobs_ids {
-            if self.pending_sampling_blobs.contains_key(id) {
-                self.pending_sampling_blobs.remove(id);
-            }
-            if self.validated_blobs.contains(id) {
-                self.validated_blobs.remove(id);
-            }
+            self.pending_sampling_blobs.remove(id);
+            self.validated_blobs.remove(id);
         }
     }
 
     async fn handle_sampling_success(&mut self, blob_id: Self::BlobId, blob: Self::Blob) {
-        // this should not even happen
-        if !self.pending_sampling_blobs.contains_key(&blob_id) {}
+        if let Some(ctx) = self.pending_sampling_blobs.get_mut(&blob_id) {
+            ctx.subnets.push(blob.column_idx as SubnetworkId);
 
-        let ctx = self.pending_sampling_blobs.get_mut(&blob_id).unwrap();
-        ctx.subnets.push(blob.column_idx as SubnetworkId);
-
-        // sampling of this blob_id terminated successfully
-        if ctx.subnets.len() == self.settings.num_samples as usize {
-            self.validated_blobs.insert(blob_id);
-            // cleanup from pending samplings
-            self.pending_sampling_blobs.remove(&blob_id);
+            // sampling of this blob_id terminated successfully
+            if ctx.subnets.len() == self.settings.num_samples as usize {
+                self.validated_blobs.insert(blob_id);
+                // cleanup from pending samplings
+                self.pending_sampling_blobs.remove(&blob_id);
+            }
+        } else {
+            unreachable!("We should not receive a sampling success from a non triggered blobId");
         }
     }
 
-    async fn handle_sampling_error(&mut self, _blob_id: Self::BlobId) {
-        // TODO: Unimplmented yet because the error handling in the service
-        // does not yet receive a blob_id
-        unimplemented!("no use case yet")
+    async fn handle_sampling_error(&mut self, blob_id: Self::BlobId) {
+        // If it fails a single time we consider it failed.
+        // We may want to abstract the sampling policies somewhere else at some point if we
+        // need to get fancier than this
+        self.pending_sampling_blobs.remove(&blob_id);
+        self.validated_blobs.remove(&blob_id);
     }
 
     async fn init_sampling(&mut self, blob_id: Self::BlobId) -> Vec<SubnetworkId> {
-        let mut ctx: SamplingContext = SamplingContext {
-            blob_id: (blob_id),
-            subnets: vec![],
-            started: Utc::now().naive_utc(),
-        };
         let subnets: Vec<SubnetworkId> = Standard
             .sample_iter(&mut self.rng)
             .take(self.settings.num_samples as usize)
             .collect();
-        ctx.subnets = subnets.clone();
+        let ctx: SamplingContext = SamplingContext {
+            subnets: subnets.clone(),
+            started: Instant::now(),
+        };
         self.pending_sampling_blobs.insert(blob_id, ctx);
         subnets
     }
+
+    fn prune(&mut self) {
+        self.prune_by_time()
+    }
 }
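
For context, the reworked backend relies on two mechanics: drawing num_samples random subnetwork indices from rand's Standard distribution, and expiring pending contexts with Instant::elapsed instead of chrono timestamps plus a monitor thread. A standalone sketch of both follows (editor's illustration, not code from the commit; the type aliases are assumptions and rand 0.8 is assumed):

use std::collections::HashMap;
use std::time::{Duration, Instant};

use rand::distributions::Standard;
use rand::prelude::*;

type SubnetworkId = u16; // assumed width of a subnetwork index
type BlobId = [u8; 32];  // assumed 32-byte blob identifier

struct SamplingContext {
    subnets: Vec<SubnetworkId>,
    started: Instant,
}

fn main() {
    let num_samples = 4usize;
    let old_blobs_check_duration = Duration::from_secs(60);
    let mut rng = rand::thread_rng();

    // init_sampling: draw `num_samples` random subnetwork indices.
    let subnets: Vec<SubnetworkId> = Standard
        .sample_iter(&mut rng)
        .take(num_samples)
        .collect();
    println!("sampling subnets {subnets:?}");

    // Track the blob as pending, stamped with a monotonic Instant.
    let mut pending: HashMap<BlobId, SamplingContext> = HashMap::new();
    pending.insert([0u8; 32], SamplingContext { subnets, started: Instant::now() });

    // prune_by_time: keep only contexts younger than the configured window.
    pending.retain(|_, ctx| ctx.started.elapsed() < old_blobs_check_duration);
    println!("pending blobs after prune: {}", pending.len());
}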

View File

@@ -20,4 +20,5 @@ pub trait DaSamplingServiceBackend {
     async fn handle_sampling_success(&mut self, blob_id: Self::BlobId, blob: Self::Blob);
     async fn handle_sampling_error(&mut self, blob_id: Self::BlobId);
     async fn init_sampling(&mut self, blob_id: Self::BlobId) -> Vec<SubnetworkId>;
+    fn prune(&mut self);
 }
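
Note that the new prune hook is synchronous while the other methods stay async behind #[async_trait]; the macro leaves non-async items untouched, so the service loop can call prune() without awaiting. A minimal shape sketch (hypothetical trait and names, not the actual definition):

use async_trait::async_trait;

#[async_trait]
pub trait SamplingBackend {
    type BlobId: Send;

    // async methods are rewritten by #[async_trait] into boxed futures
    async fn init_sampling(&mut self, blob_id: Self::BlobId);

    // plain synchronous items are left as-is and can be called directly
    fn prune(&mut self);
}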

View File

@@ -197,6 +197,8 @@ where
                         }
                     }
                 }
+                // cleanup not on time samples
+                sampler.prune();
             }
         }
         .instrument(span!(Level::TRACE, DA_SAMPLING_TAG))
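
The overall shape of the change is that stale sampling state is now pruned inline, once per iteration of the service loop, instead of by a dedicated monitoring thread. A rough sketch of that pattern (placeholder service and message types, tokio assumed; not the actual service code):

use tokio::sync::mpsc;

struct Sampler;

impl Sampler {
    fn prune(&mut self) {
        // here: drop pending contexts older than the configured window
    }
}

async fn run(mut inbound: mpsc::Receiver<u64>, mut sampler: Sampler) {
    loop {
        tokio::select! {
            Some(_request) = inbound.recv() => {
                // handle sampling requests / network events here
            }
            else => break, // all channels closed: stop the loop
        }
        // cleanup not on time samples, once per iteration
        sampler.prune();
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(8);
    tx.send(1u64).await.unwrap();
    drop(tx); // close the channel so the sketch terminates
    run(rx, Sampler).await;
}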