From 1d36a024efc3de3933364ddacc31b389f96d75b3 Mon Sep 17 00:00:00 2001 From: Al Liu Date: Thu, 5 Oct 2023 22:37:54 +0800 Subject: [PATCH] Fix missing generics on impl block (#26) * fix missing generics on impl block --- overwatch-derive/src/lib.rs | 28 ++++--- overwatch-rs/tests/generics.rs | 137 +++++++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 9 deletions(-) create mode 100644 overwatch-rs/tests/generics.rs diff --git a/overwatch-derive/src/lib.rs b/overwatch-derive/src/lib.rs index 5b6ccc2..3858db0 100644 --- a/overwatch-derive/src/lib.rs +++ b/overwatch-derive/src/lib.rs @@ -2,7 +2,7 @@ mod utils; use proc_macro_error::{abort_call_site, proc_macro_error}; use quote::{format_ident, quote}; -use syn::{punctuated::Punctuated, token::Comma, Data, DeriveInput, Field}; +use syn::{punctuated::Punctuated, token::Comma, Data, DeriveInput, Field, Generics}; #[proc_macro_derive(Services)] #[proc_macro_error] @@ -29,11 +29,12 @@ fn impl_services(input: &DeriveInput) -> proc_macro2::TokenStream { let struct_identifier = &input.ident; let data = &input.data; + let generics = &input.generics; match data { Data::Struct(DataStruct { fields: syn::Fields::Named(fields), .. - }) => impl_services_for_struct(struct_identifier, &fields.named), + }) => impl_services_for_struct(struct_identifier, generics, &fields.named), _ => { abort_call_site!("Deriving Services is only supported for named Structs"); } @@ -42,11 +43,12 @@ fn impl_services(input: &DeriveInput) -> proc_macro2::TokenStream { fn impl_services_for_struct( identifier: &proc_macro2::Ident, + generics: &Generics, fields: &Punctuated, ) -> proc_macro2::TokenStream { - let settings = generate_services_settings(identifier, fields); - let unique_ids_check = generate_assert_unique_identifiers(identifier, fields); - let services_impl = generate_services_impl(identifier, fields); + let settings = generate_services_settings(identifier, generics, fields); + let unique_ids_check = generate_assert_unique_identifiers(identifier, generics, fields); + let services_impl = generate_services_impl(identifier, generics, fields); quote! { #unique_ids_check @@ -59,6 +61,7 @@ fn impl_services_for_struct( fn generate_services_settings( services_identifier: &proc_macro2::Ident, + generics: &Generics, fields: &Punctuated, ) -> proc_macro2::TokenStream { let services_settings = fields.iter().map(|field| { @@ -68,9 +71,10 @@ fn generate_services_settings( quote!(pub #service_name: <#_type as ::overwatch_rs::services::ServiceData>::Settings) }); let services_settings_identifier = service_settings_identifier_from(services_identifier); + let where_clause = &generics.where_clause; quote! { #[derive(::std::clone::Clone, ::std::fmt::Debug)] - pub struct #services_settings_identifier { + pub struct #services_settings_identifier #generics #where_clause { #( #services_settings ),* } } @@ -78,6 +82,7 @@ fn generate_services_settings( fn generate_assert_unique_identifiers( services_identifier: &proc_macro2::Ident, + generics: &Generics, fields: &Punctuated, ) -> proc_macro2::TokenStream { let services_ids = fields.iter().map(|field| { @@ -90,14 +95,18 @@ fn generate_assert_unique_identifiers( "__{}__CONST_CHECK_UNIQUE_SERVICES_IDS", services_identifier.to_string().to_uppercase() ); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! { - const #services_ids_check: () = assert!(::overwatch_rs::utils::const_checks::unique_ids(&[#( #services_ids ),*])); + impl #impl_generics #services_identifier #ty_generics #where_clause { + const #services_ids_check: () = assert!(::overwatch_rs::utils::const_checks::unique_ids(&[#( #services_ids ),*])); + } } } fn generate_services_impl( services_identifier: &proc_macro2::Ident, + generics: &Generics, fields: &Punctuated, ) -> proc_macro2::TokenStream { let services_settings_identifier = service_settings_identifier_from(services_identifier); @@ -107,10 +116,11 @@ fn generate_services_impl( let impl_stop = generate_stop_impl(fields); let impl_relay = generate_request_relay_impl(fields); let impl_update_settings = generate_update_settings_impl(fields); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! { - impl ::overwatch_rs::overwatch::Services for #services_identifier { - type Settings = #services_settings_identifier; + impl #impl_generics ::overwatch_rs::overwatch::Services for #services_identifier #ty_generics #where_clause { + type Settings = #services_settings_identifier #ty_generics; #impl_new diff --git a/overwatch-rs/tests/generics.rs b/overwatch-rs/tests/generics.rs new file mode 100644 index 0000000..9a88406 --- /dev/null +++ b/overwatch-rs/tests/generics.rs @@ -0,0 +1,137 @@ +use async_trait::async_trait; +use futures::future::select; +use overwatch_derive::Services; +use overwatch_rs::overwatch::OverwatchRunner; +use overwatch_rs::services::handle::{ServiceHandle, ServiceStateHandle}; +use overwatch_rs::services::relay::RelayMessage; +use overwatch_rs::services::state::{NoOperator, NoState}; +use overwatch_rs::services::{ServiceCore, ServiceData, ServiceId}; +use std::fmt::Debug; +use std::time::Duration; +use tokio::time::sleep; + +pub struct GenericService +where + T: Debug + 'static + Sync, +{ + state: ServiceStateHandle, + _phantom: std::marker::PhantomData, +} + +#[derive(Clone, Debug)] +pub struct GenericServiceMessage(String); + +impl RelayMessage for GenericServiceMessage {} + +impl ServiceData for GenericService +where + T: Debug + 'static + Sync, +{ + const SERVICE_ID: ServiceId = "FooService"; + type Settings = (); + type State = NoState; + type StateOperator = NoOperator; + type Message = GenericServiceMessage; +} + +#[async_trait] +impl ServiceCore for GenericService +where + T: Debug + 'static + Sync, +{ + fn init(state: ServiceStateHandle) -> Result { + Ok(Self { + state, + _phantom: std::marker::PhantomData, + }) + } + + async fn run(mut self) -> Result<(), overwatch_rs::DynError> { + use tokio::io::{self, AsyncWriteExt}; + + let Self { + state: ServiceStateHandle { + mut inbound_relay, .. + }, + .. + } = self; + + let generic = async move { + let mut stdout = io::stdout(); + while let Some(message) = inbound_relay.recv().await { + match message.0.as_ref() { + "stop" => { + stdout + .write_all(b"genericing service stopping\n") + .await + .expect("stop Output wrote"); + break; + } + m => { + stdout + .write_all(format!("{m}\n").as_bytes()) + .await + .expect("Message output wrote"); + } + } + } + }; + + let idle = async move { + let mut stdout = io::stdout(); + loop { + stdout + .write_all(b"Waiting for generic process to finish...\n") + .await + .expect("Message output wrote"); + sleep(Duration::from_millis(50)).await; + } + }; + + select(Box::pin(generic), Box::pin(idle)).await; + Ok(()) + } +} + +#[derive(Services)] +struct TestApp +where + T: Debug + 'static + Sync, +{ + generic_service: ServiceHandle>, +} + +#[test] +fn derive_generic_service() { + let settings: TestAppServiceSettings = TestAppServiceSettings { + generic_service: (), + }; + let overwatch = OverwatchRunner::>::run(settings, None).unwrap(); + let handle = overwatch.handle().clone(); + let generic_service_relay = handle.relay::>(); + + overwatch.spawn(async move { + let generic_service_relay = generic_service_relay + .connect() + .await + .expect("A connection to the generic service is established"); + + for _ in 0..3 { + generic_service_relay + .send(GenericServiceMessage("Hey oh let's go!".to_string())) + .await + .expect("Message is sent"); + } + sleep(Duration::from_millis(50)).await; + generic_service_relay + .send(GenericServiceMessage("stop".to_string())) + .await + .expect("stop message to be sent"); + }); + + overwatch.spawn(async move { + sleep(Duration::from_secs(1)).await; + handle.shutdown().await; + }); + overwatch.wait_finished(); +}