Fix missing generics on impl block (#26)

* fix missing generics on impl block
This commit is contained in:
Al Liu 2023-10-05 22:37:54 +08:00 committed by GitHub
parent 9b865c3ece
commit 1d36a024ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 156 additions and 9 deletions

View File

@ -2,7 +2,7 @@ mod utils;
use proc_macro_error::{abort_call_site, proc_macro_error}; use proc_macro_error::{abort_call_site, proc_macro_error};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{punctuated::Punctuated, token::Comma, Data, DeriveInput, Field}; use syn::{punctuated::Punctuated, token::Comma, Data, DeriveInput, Field, Generics};
#[proc_macro_derive(Services)] #[proc_macro_derive(Services)]
#[proc_macro_error] #[proc_macro_error]
@ -29,11 +29,12 @@ fn impl_services(input: &DeriveInput) -> proc_macro2::TokenStream {
let struct_identifier = &input.ident; let struct_identifier = &input.ident;
let data = &input.data; let data = &input.data;
let generics = &input.generics;
match data { match data {
Data::Struct(DataStruct { Data::Struct(DataStruct {
fields: syn::Fields::Named(fields), fields: syn::Fields::Named(fields),
.. ..
}) => impl_services_for_struct(struct_identifier, &fields.named), }) => impl_services_for_struct(struct_identifier, generics, &fields.named),
_ => { _ => {
abort_call_site!("Deriving Services is only supported for named Structs"); abort_call_site!("Deriving Services is only supported for named Structs");
} }
@ -42,11 +43,12 @@ fn impl_services(input: &DeriveInput) -> proc_macro2::TokenStream {
fn impl_services_for_struct( fn impl_services_for_struct(
identifier: &proc_macro2::Ident, identifier: &proc_macro2::Ident,
generics: &Generics,
fields: &Punctuated<Field, Comma>, fields: &Punctuated<Field, Comma>,
) -> proc_macro2::TokenStream { ) -> proc_macro2::TokenStream {
let settings = generate_services_settings(identifier, fields); let settings = generate_services_settings(identifier, generics, fields);
let unique_ids_check = generate_assert_unique_identifiers(identifier, fields); let unique_ids_check = generate_assert_unique_identifiers(identifier, generics, fields);
let services_impl = generate_services_impl(identifier, fields); let services_impl = generate_services_impl(identifier, generics, fields);
quote! { quote! {
#unique_ids_check #unique_ids_check
@ -59,6 +61,7 @@ fn impl_services_for_struct(
fn generate_services_settings( fn generate_services_settings(
services_identifier: &proc_macro2::Ident, services_identifier: &proc_macro2::Ident,
generics: &Generics,
fields: &Punctuated<Field, Comma>, fields: &Punctuated<Field, Comma>,
) -> proc_macro2::TokenStream { ) -> proc_macro2::TokenStream {
let services_settings = fields.iter().map(|field| { let services_settings = fields.iter().map(|field| {
@ -68,9 +71,10 @@ fn generate_services_settings(
quote!(pub #service_name: <#_type as ::overwatch_rs::services::ServiceData>::Settings) quote!(pub #service_name: <#_type as ::overwatch_rs::services::ServiceData>::Settings)
}); });
let services_settings_identifier = service_settings_identifier_from(services_identifier); let services_settings_identifier = service_settings_identifier_from(services_identifier);
let where_clause = &generics.where_clause;
quote! { quote! {
#[derive(::std::clone::Clone, ::std::fmt::Debug)] #[derive(::std::clone::Clone, ::std::fmt::Debug)]
pub struct #services_settings_identifier { pub struct #services_settings_identifier #generics #where_clause {
#( #services_settings ),* #( #services_settings ),*
} }
} }
@ -78,6 +82,7 @@ fn generate_services_settings(
fn generate_assert_unique_identifiers( fn generate_assert_unique_identifiers(
services_identifier: &proc_macro2::Ident, services_identifier: &proc_macro2::Ident,
generics: &Generics,
fields: &Punctuated<Field, Comma>, fields: &Punctuated<Field, Comma>,
) -> proc_macro2::TokenStream { ) -> proc_macro2::TokenStream {
let services_ids = fields.iter().map(|field| { let services_ids = fields.iter().map(|field| {
@ -90,14 +95,18 @@ fn generate_assert_unique_identifiers(
"__{}__CONST_CHECK_UNIQUE_SERVICES_IDS", "__{}__CONST_CHECK_UNIQUE_SERVICES_IDS",
services_identifier.to_string().to_uppercase() services_identifier.to_string().to_uppercase()
); );
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
quote! { quote! {
const #services_ids_check: () = assert!(::overwatch_rs::utils::const_checks::unique_ids(&[#( #services_ids ),*])); impl #impl_generics #services_identifier #ty_generics #where_clause {
const #services_ids_check: () = assert!(::overwatch_rs::utils::const_checks::unique_ids(&[#( #services_ids ),*]));
}
} }
} }
fn generate_services_impl( fn generate_services_impl(
services_identifier: &proc_macro2::Ident, services_identifier: &proc_macro2::Ident,
generics: &Generics,
fields: &Punctuated<Field, Comma>, fields: &Punctuated<Field, Comma>,
) -> proc_macro2::TokenStream { ) -> proc_macro2::TokenStream {
let services_settings_identifier = service_settings_identifier_from(services_identifier); let services_settings_identifier = service_settings_identifier_from(services_identifier);
@ -107,10 +116,11 @@ fn generate_services_impl(
let impl_stop = generate_stop_impl(fields); let impl_stop = generate_stop_impl(fields);
let impl_relay = generate_request_relay_impl(fields); let impl_relay = generate_request_relay_impl(fields);
let impl_update_settings = generate_update_settings_impl(fields); let impl_update_settings = generate_update_settings_impl(fields);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
quote! { quote! {
impl ::overwatch_rs::overwatch::Services for #services_identifier { impl #impl_generics ::overwatch_rs::overwatch::Services for #services_identifier #ty_generics #where_clause {
type Settings = #services_settings_identifier; type Settings = #services_settings_identifier #ty_generics;
#impl_new #impl_new

View File

@ -0,0 +1,137 @@
use async_trait::async_trait;
use futures::future::select;
use overwatch_derive::Services;
use overwatch_rs::overwatch::OverwatchRunner;
use overwatch_rs::services::handle::{ServiceHandle, ServiceStateHandle};
use overwatch_rs::services::relay::RelayMessage;
use overwatch_rs::services::state::{NoOperator, NoState};
use overwatch_rs::services::{ServiceCore, ServiceData, ServiceId};
use std::fmt::Debug;
use std::time::Duration;
use tokio::time::sleep;
pub struct GenericService<T: Send>
where
T: Debug + 'static + Sync,
{
state: ServiceStateHandle<Self>,
_phantom: std::marker::PhantomData<T>,
}
#[derive(Clone, Debug)]
pub struct GenericServiceMessage(String);
impl RelayMessage for GenericServiceMessage {}
impl<T: Send> ServiceData for GenericService<T>
where
T: Debug + 'static + Sync,
{
const SERVICE_ID: ServiceId = "FooService";
type Settings = ();
type State = NoState<Self::Settings>;
type StateOperator = NoOperator<Self::State>;
type Message = GenericServiceMessage;
}
#[async_trait]
impl<T: Send> ServiceCore for GenericService<T>
where
T: Debug + 'static + Sync,
{
fn init(state: ServiceStateHandle<Self>) -> Result<Self, overwatch_rs::DynError> {
Ok(Self {
state,
_phantom: std::marker::PhantomData,
})
}
async fn run(mut self) -> Result<(), overwatch_rs::DynError> {
use tokio::io::{self, AsyncWriteExt};
let Self {
state: ServiceStateHandle {
mut inbound_relay, ..
},
..
} = self;
let generic = async move {
let mut stdout = io::stdout();
while let Some(message) = inbound_relay.recv().await {
match message.0.as_ref() {
"stop" => {
stdout
.write_all(b"genericing service stopping\n")
.await
.expect("stop Output wrote");
break;
}
m => {
stdout
.write_all(format!("{m}\n").as_bytes())
.await
.expect("Message output wrote");
}
}
}
};
let idle = async move {
let mut stdout = io::stdout();
loop {
stdout
.write_all(b"Waiting for generic process to finish...\n")
.await
.expect("Message output wrote");
sleep(Duration::from_millis(50)).await;
}
};
select(Box::pin(generic), Box::pin(idle)).await;
Ok(())
}
}
#[derive(Services)]
struct TestApp<T: Send>
where
T: Debug + 'static + Sync,
{
generic_service: ServiceHandle<GenericService<T>>,
}
#[test]
fn derive_generic_service() {
let settings: TestAppServiceSettings<String> = TestAppServiceSettings {
generic_service: (),
};
let overwatch = OverwatchRunner::<TestApp<String>>::run(settings, None).unwrap();
let handle = overwatch.handle().clone();
let generic_service_relay = handle.relay::<GenericService<String>>();
overwatch.spawn(async move {
let generic_service_relay = generic_service_relay
.connect()
.await
.expect("A connection to the generic service is established");
for _ in 0..3 {
generic_service_relay
.send(GenericServiceMessage("Hey oh let's go!".to_string()))
.await
.expect("Message is sent");
}
sleep(Duration::from_millis(50)).await;
generic_service_relay
.send(GenericServiceMessage("stop".to_string()))
.await
.expect("stop message to be sent");
});
overwatch.spawn(async move {
sleep(Duration::from_secs(1)).await;
handle.shutdown().await;
});
overwatch.wait_finished();
}