abstract rng

holisticode 2024-08-29 12:44:07 -05:00
parent 47d5e2cdc8
commit b3c5379b6a
3 changed files with 33 additions and 20 deletions

View File

@@ -6,7 +6,7 @@ use std::time::{Duration, Instant};
// crates
use rand::distributions::Standard;
use rand::prelude::*;
use rand_chacha::ChaCha20Rng;
use rand::SeedableRng;
use kzgrs_backend::common::blob::DaBlob;
@@ -28,16 +28,16 @@ pub struct KzgrsDaSamplerSettings {
pub blobs_validity_duration: Duration,
}
-pub struct KzgrsDaSampler {
+pub struct KzgrsDaSampler<R: Rng> {
settings: KzgrsDaSamplerSettings,
validated_blobs: BTreeSet<BlobId>,
pending_sampling_blobs: HashMap<BlobId, SamplingContext>,
// TODO: is there a better place for this? Do we need to have this even globally?
// Do we already have some source of randomness already?
-rng: ChaCha20Rng,
+rng: R,
}
-impl KzgrsDaSampler {
+impl<R: Rng> KzgrsDaSampler<R> {
fn prune_by_time(&mut self) {
self.pending_sampling_blobs.retain(|_blob_id, context| {
context.started.elapsed() < self.settings.old_blobs_check_duration
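For reference, the `prune_by_time` helper above keeps only pending sampling contexts younger than `old_blobs_check_duration`. A standalone sketch of the same `retain` plus `Instant::elapsed` pattern, using placeholder types rather than the repo's `SamplingContext` and `BlobId`:

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// Placeholder stand-in for SamplingContext: only the start time matters here.
struct Context {
    started: Instant,
}

fn prune_by_time(pending: &mut HashMap<u64, Context>, max_age: Duration) {
    // Keep entries whose sampling started less than `max_age` ago; drop the rest.
    pending.retain(|_blob_id, ctx| ctx.started.elapsed() < max_age);
}

fn main() {
    let mut pending = HashMap::new();
    pending.insert(1u64, Context { started: Instant::now() });
    prune_by_time(&mut pending, Duration::from_secs(60));
    assert_eq!(pending.len(), 1); // a fresh entry survives a 60 s window
}
```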
@@ -46,18 +46,18 @@ impl KzgrsDaSampler {
}
#[async_trait::async_trait]
-impl DaSamplingServiceBackend for KzgrsDaSampler {
+impl<R: Rng + Sync + Send> DaSamplingServiceBackend<R> for KzgrsDaSampler<R> {
type Settings = KzgrsDaSamplerSettings;
type BlobId = BlobId;
type Blob = DaBlob;
-fn new(settings: Self::Settings) -> Self {
+fn new(settings: Self::Settings, rng: R) -> Self {
let bt: BTreeSet<BlobId> = BTreeSet::new();
Self {
settings,
validated_blobs: bt,
pending_sampling_blobs: HashMap::new(),
-rng: ChaCha20Rng::from_entropy(),
+rng: rng,
}
}
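With this constructor the backend no longer calls `ChaCha20Rng::from_entropy()` itself; whoever builds it supplies the RNG. A minimal, self-contained sketch of that injection pattern (the `Sampler` type below is illustrative, not the repo's `KzgrsDaSampler`):

```rust
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;

// Illustrative sampler that is generic over its randomness source.
struct Sampler<R: Rng> {
    rng: R,
}

impl<R: Rng> Sampler<R> {
    fn new(rng: R) -> Self {
        Self { rng }
    }

    // Pick one of `n` column indices uniformly at random.
    fn pick(&mut self, n: usize) -> usize {
        self.rng.gen_range(0..n)
    }
}

fn main() {
    // Production-style construction: seed from OS entropy.
    let mut prod = Sampler::new(ChaCha20Rng::from_entropy());
    let _ = prod.pick(16);

    // Test-style construction: a fixed seed makes every pick reproducible.
    let mut a = Sampler::new(ChaCha20Rng::seed_from_u64(7));
    let mut b = Sampler::new(ChaCha20Rng::seed_from_u64(7));
    assert_eq!(a.pick(16), b.pick(16));
}
```

Keeping the backend generic over `R: Rng` instead of hard-coding `ChaCha20Rng` is what lets tests drive it deterministically while production code keeps seeding from entropy.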

View File

@@ -4,17 +4,18 @@ pub mod kzgrs;
use std::collections::BTreeSet;
// crates
+use rand::Rng;
//
// internal
use nomos_da_network_core::SubnetworkId;
#[async_trait::async_trait]
-pub trait DaSamplingServiceBackend {
+pub trait DaSamplingServiceBackend<R: Rng> {
type Settings;
type BlobId;
type Blob;
-fn new(settings: Self::Settings) -> Self;
+fn new(settings: Self::Settings, rng: R) -> Self;
async fn get_validated_blobs(&self) -> BTreeSet<Self::BlobId>;
async fn mark_in_block(&mut self, blobs_ids: &[Self::BlobId]);
async fn handle_sampling_success(&mut self, blob_id: Self::BlobId, blob: Self::Blob);
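Because the trait itself is now parameterised by `R: Rng`, every implementor receives its randomness from the service that constructs it. A simplified, synchronous stand-in (not the actual async `DaSamplingServiceBackend`) showing how a test double could satisfy a trait of this shape:

```rust
use std::collections::BTreeSet;

use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;

// Simplified, synchronous stand-in for the trait shape in this diff.
trait SamplingBackend<R: Rng> {
    type Settings;
    type BlobId: Ord;

    fn new(settings: Self::Settings, rng: R) -> Self;
    fn validated_blobs(&self) -> BTreeSet<Self::BlobId>;
}

// A test double that works with whatever RNG it is handed.
struct MockBackend<R: Rng> {
    rng: R,
    validated: BTreeSet<[u8; 32]>,
}

impl<R: Rng> SamplingBackend<R> for MockBackend<R> {
    type Settings = ();
    type BlobId = [u8; 32];

    fn new(_settings: Self::Settings, rng: R) -> Self {
        Self { rng, validated: BTreeSet::new() }
    }

    fn validated_blobs(&self) -> BTreeSet<Self::BlobId> {
        self.validated.clone()
    }
}

fn main() {
    // A seeded RNG keeps test behaviour reproducible; production would inject entropy.
    let mut backend = MockBackend::new((), ChaCha20Rng::seed_from_u64(0));
    let _byte: u8 = backend.rng.gen();
    assert!(backend.validated_blobs().is_empty());
}
```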

View File

@@ -6,6 +6,10 @@ use std::collections::BTreeSet;
use std::fmt::Debug;
// crates
use rand::prelude::*;
+use rand::Rng;
+use rand_chacha::rand_core::CryptoRngCore;
+use rand_chacha::ChaCha20Rng;
use tokio_stream::StreamExt;
use tracing::{error, span, Instrument, Level};
// internal
@@ -46,9 +50,10 @@ pub struct DaSamplingServiceSettings<BackendSettings, NetworkSettings> {
impl<B: 'static> RelayMessage for DaSamplingServiceMsg<B> {}
-pub struct DaSamplingService<Backend, N, S>
+pub struct DaSamplingService<Backend, N, S, R>
where
-Backend: DaSamplingServiceBackend + Send,
+R: Rng,
+Backend: DaSamplingServiceBackend<R> + Send,
Backend::Settings: Clone,
Backend::Blob: Debug + 'static,
Backend::BlobId: Debug + 'static,
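Only the `where` clause of `DaSamplingService` is visible in this hunk. If the new `R` parameter does not appear in any of the struct's fields, Rust needs a marker such as `PhantomData` to accept it; a standalone sketch of that pattern, with made-up names:

```rust
use std::marker::PhantomData;

use rand::Rng;
use rand_chacha::ChaCha20Rng;

// Sketch: a service generic over the RNG type its backend uses, without
// holding an RNG itself. PhantomData marks the otherwise-unused parameter.
struct SamplingService<Backend, R: Rng> {
    sampler: Backend,
    _rng: PhantomData<R>,
}

impl<Backend, R: Rng> SamplingService<Backend, R> {
    fn new(sampler: Backend) -> Self {
        Self {
            sampler,
            _rng: PhantomData,
        }
    }
}

fn main() {
    // Any backend value works for the sketch; the RNG type is fixed at the call site.
    let svc: SamplingService<Vec<u8>, ChaCha20Rng> = SamplingService::new(Vec::new());
    assert!(svc.sampler.is_empty());
}
```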
@@ -60,9 +65,10 @@ where
sampler: Backend,
}
-impl<Backend, N, S> DaSamplingService<Backend, N, S>
+impl<Backend, N, S, R> DaSamplingService<Backend, N, S, R>
where
-Backend: DaSamplingServiceBackend<BlobId = BlobId, Blob = DaBlob> + Send + 'static,
+R: Rng,
+Backend: DaSamplingServiceBackend<R, BlobId = BlobId, Blob = DaBlob> + Send + 'static,
Backend::Settings: Clone,
N: NetworkAdapter + Send + 'static,
N::Settings: Clone,
@@ -126,9 +132,10 @@ where
}
}
-impl<Backend, N, S> ServiceData for DaSamplingService<Backend, N, S>
+impl<Backend, N, S, R> ServiceData for DaSamplingService<Backend, N, S, R>
where
-Backend: DaSamplingServiceBackend + Send,
+R: Rng,
+Backend: DaSamplingServiceBackend<R> + Send,
Backend::Settings: Clone,
Backend::Blob: Debug + 'static,
Backend::BlobId: Debug + 'static,
@@ -143,9 +150,10 @@ where
}
#[async_trait::async_trait]
-impl<Backend, N, S> ServiceCore for DaSamplingService<Backend, N, S>
+impl<Backend, N, S, R> ServiceCore for DaSamplingService<Backend, N, S, R>
where
-Backend: DaSamplingServiceBackend<BlobId = BlobId, Blob = DaBlob> + Send + Sync + 'static,
+R: Rng,
+Backend: DaSamplingServiceBackend<R, BlobId = BlobId, Blob = DaBlob> + Send + Sync + 'static,
Backend::Settings: Clone + Send + Sync + 'static,
N: NetworkAdapter + Send + Sync + 'static,
N::Settings: Clone + Send + Sync + 'static,
@@ -156,11 +164,12 @@ where
} = service_state.settings_reader.get_updated_settings();
let network_relay = service_state.overwatch_handle.relay();
+let rng: R = ChaCha20Rng::from_entropy();
Ok(Self {
network_relay,
service_state,
-sampler: Backend::new(sampling_settings),
+sampler: Backend::new(sampling_settings, rng),
})
}
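Here the service seeds the RNG with `ChaCha20Rng::from_entropy()` and passes it to the backend. For a component that should stay fully generic over `R`, one option is to bound `R` by `SeedableRng` as well and let the caller pick the concrete type; a standalone sketch of that variant (not the commit's code):

```rust
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;

// Sketch: a component generic over its RNG can still create one itself
// when the bound includes SeedableRng.
struct Service<R> {
    rng: R,
}

impl<R: Rng + SeedableRng> Service<R> {
    fn init() -> Self {
        // `R::from_entropy()` works for any seedable RNG, so the concrete type
        // (ChaCha20Rng here) is picked by the caller rather than hard-coded.
        Self {
            rng: R::from_entropy(),
        }
    }
}

fn main() {
    let mut service: Service<ChaCha20Rng> = Service::init();
    let _sample: u64 = service.rng.gen();
}
```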
@@ -187,18 +196,21 @@ where
tokio::select! {
Some(service_message) = service_state.inbound_relay.recv() => {
Self::handle_service_message(service_message, &mut network_adapter, &mut sampler).await;
// cleanup not on time samples
sampler.prune();
}
Some(sampling_message) = sampling_message_stream.next() => {
Self::handle_sampling_message(sampling_message, &mut sampler).await;
// cleanup not on time samples
sampler.prune();
}
Some(msg) = lifecycle_stream.next() => {
if Self::should_stop_service(msg).await {
break;
}
}
}
// cleanup not on time samples
sampler.prune();
}
}
.instrument(span!(Level::TRACE, DA_SAMPLING_TAG))
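The run loop above multiplexes the inbound relay, the sampling protocol stream and lifecycle events with `tokio::select!`, pruning stale sampling state as it iterates. A self-contained sketch of that loop shape, with plain mpsc channels standing in for the relays and streams:

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

use tokio::sync::mpsc;

// Placeholder pending-sampling entry keyed by a fake blob id.
struct Context {
    started: Instant,
}

fn prune(pending: &mut HashMap<u64, Context>, max_age: Duration) {
    pending.retain(|_, ctx| ctx.started.elapsed() < max_age);
}

#[tokio::main]
async fn main() {
    let (service_tx, mut service_rx) = mpsc::channel::<u64>(8);
    let (sampling_tx, mut sampling_rx) = mpsc::channel::<u64>(8);
    let mut pending: HashMap<u64, Context> = HashMap::new();

    // Feed one message per channel, then close both so the loop can exit.
    service_tx.send(1).await.unwrap();
    sampling_tx.send(2).await.unwrap();
    drop(service_tx);
    drop(sampling_tx);

    loop {
        tokio::select! {
            Some(blob_id) = service_rx.recv() => {
                // Stand-in for handle_service_message: start tracking a blob.
                pending.insert(blob_id, Context { started: Instant::now() });
            }
            Some(blob_id) = sampling_rx.recv() => {
                // Stand-in for handle_sampling_message: resolve a blob.
                pending.remove(&blob_id);
            }
            else => break,
        }
        // Clean up contexts that are past their window, once per iteration.
        prune(&mut pending, Duration::from_secs(60));
    }
}
```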