From 73bc2b691a397f8768253e63cb65bae246385aec Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Sat, 7 Feb 2026 15:57:57 +0100 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=8E=A8=20core:=20Re-arrange=20attribu?= =?UTF-8?q?tes=20to=20satisfy=20rust-analyzer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For some strange reason, RA complains about the order. 🤷 --- zlink-core/src/connection/chain/mod.rs | 2 +- zlink-core/src/connection/tests/connection_tests.rs | 2 +- zlink-core/src/varlink_service/proxy.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/zlink-core/src/connection/chain/mod.rs b/zlink-core/src/connection/chain/mod.rs index 4d754d81..e1fb044b 100644 --- a/zlink-core/src/connection/chain/mod.rs +++ b/zlink-core/src/connection/chain/mod.rs @@ -590,8 +590,8 @@ mod tests { Ok(()) } - #[cfg(feature = "std")] #[tokio::test] + #[cfg(feature = "std")] async fn chain_from_iter_with_fds() -> crate::Result<()> { use crate::{ connection::socket::{ReadHalf, WriteHalf}, diff --git a/zlink-core/src/connection/tests/connection_tests.rs b/zlink-core/src/connection/tests/connection_tests.rs index 68d9ee80..3a710e78 100644 --- a/zlink-core/src/connection/tests/connection_tests.rs +++ b/zlink-core/src/connection/tests/connection_tests.rs @@ -2,8 +2,8 @@ use crate::{test_utils::mock_socket::MockSocket, Connection}; -#[cfg(feature = "std")] #[tokio::test] +#[cfg(feature = "std")] async fn peer_credentials_mock_socket() { // Get the expected credentials of the current process. let expected_uid = rustix::process::getuid(); diff --git a/zlink-core/src/varlink_service/proxy.rs b/zlink-core/src/varlink_service/proxy.rs index cbc8c127..f71cb160 100644 --- a/zlink-core/src/varlink_service/proxy.rs +++ b/zlink-core/src/varlink_service/proxy.rs @@ -84,12 +84,12 @@ use super::{Error, Info, InterfaceDescription, OwnedError, OwnedInfo}; /// # Ok(()) /// # } /// ``` -#[cfg(feature = "std")] #[proxy( interface = "org.varlink.service", crate = "crate", chain_name = "Chain" )] +#[cfg(feature = "std")] pub trait Proxy { /// Get information about a Varlink service. /// From 0ae5fe85a82f0d5b5cadee0ae19529e8cd610f6e Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Tue, 3 Feb 2026 20:44:03 +0100 Subject: [PATCH 2/2] =?UTF-8?q?=E2=9C=A8=20core,macros:=20Support=20FDs=20?= =?UTF-8?q?in=20Service=20impls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This includes both the `Service` trait itself and the `service` macro. The macro API is almost the same as the `proxy` macro already provides. --- README.md | 39 ++- zlink-core/src/server/mod.rs | 39 ++- zlink-core/src/server/service.rs | 35 ++- zlink-macros/src/service/codegen.rs | 359 ++++++++++++++++++++---- zlink-macros/src/service/method.rs | 139 +++++++++- zlink-macros/src/service/types.rs | 3 + zlink/tests/service-macro.rs | 408 ++++++++++++++++++++++++++++ 7 files changed, 944 insertions(+), 78 deletions(-) diff --git a/README.md b/README.md index e710c442..0ed3c2ea 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ use serde::{Deserialize, Serialize}; use tokio::{select, sync::oneshot, fs::remove_file}; use zlink::{ proxy, - service::{MethodReply, Service}, + service::{self, MethodReply, Service}, connection::{Connection, Socket}, unix, Call, ReplyError, Server, }; @@ -196,7 +196,7 @@ where where Self: 'ser; type ReplyStreamParams = (); - type ReplyStream = futures_util::stream::Empty>; + type ReplyStream = futures_util::stream::Empty>; type ReplyError<'ser> = CalculatorError<'ser> where Self: 'ser; @@ -205,15 +205,25 @@ where &'service mut self, call: &'service Call>, conn: &mut Connection, - ) -> MethodReply, Self::ReplyStream, Self::ReplyError<'service>> { - match call.method() { + fds: Vec, + ) -> service::HandleResult< + Self::ReplyParams<'service>, + Self::ReplyStream, + Self::ReplyError<'service>, + > { + let _ = (conn, fds); + let reply = match call.method() { CalculatorMethod::Add { a, b } => { self.operations.push(format!("add({}, {})", a, b)); - MethodReply::Single(Some(CalculatorReply::Result(CalculationResult { result: a + b }))) + MethodReply::Single(Some(CalculatorReply::Result( + CalculationResult { result: a + b }, + ))) } CalculatorMethod::Multiply { x, y } => { self.operations.push(format!("multiply({}, {})", x, y)); - MethodReply::Single(Some(CalculatorReply::Result(CalculationResult { result: x * y }))) + MethodReply::Single(Some(CalculatorReply::Result( + CalculationResult { result: x * y }, + ))) } CalculatorMethod::Divide { dividend, divisor } => { if *divisor == 0.0 { @@ -226,20 +236,25 @@ where reason: "must be within range", }) } else { - self.operations.push(format!("divide({}, {})", dividend, divisor)); - MethodReply::Single(Some(CalculatorReply::Result(CalculationResult { - result: dividend / divisor, - }))) + self.operations + .push(format!("divide({}, {})", dividend, divisor)); + MethodReply::Single(Some(CalculatorReply::Result( + CalculationResult { + result: dividend / divisor, + }, + ))) } } CalculatorMethod::GetStats => { - let ops: Vec<&str> = self.operations.iter().map(|s| s.as_str()).collect(); + let ops: Vec<&str> = + self.operations.iter().map(|s| s.as_str()).collect(); MethodReply::Single(Some(CalculatorReply::Stats(Statistics { count: self.operations.len() as u64, operations: ops, }))) } - } + }; + (reply, Vec::new()) } } diff --git a/zlink-core/src/server/mod.rs b/zlink-core/src/server/mod.rs index 0b34c827..c883b8b4 100644 --- a/zlink-core/src/server/mod.rs +++ b/zlink-core/src/server/mod.rs @@ -109,7 +109,10 @@ where // 2. Read method calls from the existing connections and handle them. (idx, result) = read_select_all.fuse() => { #[cfg(feature = "std")] - let call = result.map(|(call, _fds)| call); + let (call, fds) = match result { + Ok((call, fds)) => (Ok(call), fds), + Err(e) => (Err(e), alloc::vec![]), + }; #[cfg(not(feature = "std"))] let call = result; last_method_call_winner = Some(idx); @@ -118,7 +121,13 @@ where let mut remove = true; match call { Ok(call) => { - match self.handle_call(call, &mut connections[idx]).await { + #[cfg(feature = "std")] + let result = + self.handle_call(call, &mut connections[idx], fds).await; + #[cfg(not(feature = "std"))] + let result = + self.handle_call(call, &mut connections[idx]).await; + match result { Ok(None) => remove = false, Ok(Some(s)) => stream = Some(s), Err(e) => warn!("Error writing to connection: {:?}", e), @@ -137,15 +146,20 @@ where } // 3. Read replies from the reply streams and send them off. reply = reply_stream_select_all.fuse() => { - let (idx, reply) = reply; + let (idx, item) = reply; last_reply_stream_winner = Some(idx); let id = reply_streams[idx].conn.id(); - match reply { - Some(reply) => { + match item { + Some(item) => { + #[cfg(feature = "std")] + let (reply, fds) = item; + #[cfg(not(feature = "std"))] + let reply = item; + #[cfg(feature = "std")] let send_result = - reply_streams[idx].conn.send_reply(&reply, alloc::vec![]).await; + reply_streams[idx].conn.send_reply(&reply, fds).await; #[cfg(not(feature = "std"))] let send_result = reply_streams[idx].conn.send_reply(&reply).await; if let Err(e) = send_result { @@ -168,20 +182,27 @@ where &mut self, call: Call>, conn: &mut Connection, + #[cfg(feature = "std")] fds: Vec, ) -> crate::Result> { let mut stream = None; - match self.service.handle(&call, conn).await { + + #[cfg(feature = "std")] + let (reply, reply_fds) = self.service.handle(&call, conn, fds).await; + #[cfg(not(feature = "std"))] + let reply = self.service.handle(&call, conn).await; + + match reply { // Don't send replies or errors for oneway calls. MethodReply::Single(_) | MethodReply::Error(_) if call.oneway() => (), MethodReply::Single(params) => { let reply = Reply::new(params).set_continues(Some(false)); #[cfg(feature = "std")] - conn.send_reply(&reply, alloc::vec![]).await?; + conn.send_reply(&reply, reply_fds).await?; #[cfg(not(feature = "std"))] conn.send_reply(&reply).await?; } #[cfg(feature = "std")] - MethodReply::Error(err) => conn.send_error(&err, alloc::vec![]).await?, + MethodReply::Error(err) => conn.send_error(&err, reply_fds).await?, #[cfg(not(feature = "std"))] MethodReply::Error(err) => conn.send_error(&err).await?, MethodReply::Multi(s) => { diff --git a/zlink-core/src/server/service.rs b/zlink-core/src/server/service.rs index 4968b1c2..57f65fb7 100644 --- a/zlink-core/src/server/service.rs +++ b/zlink-core/src/server/service.rs @@ -7,6 +7,19 @@ use serde::{Deserialize, Serialize}; use crate::{connection::Socket, Call, Connection, Reply}; +/// The item type that a [`Service::ReplyStream`] yields. +/// +/// On `std`, this is a tuple of the reply and the file descriptors to send with it. On `no_std`, +/// this is just the reply. +#[cfg(feature = "std")] +pub type ReplyStreamItem = (Reply, Vec); +/// The item type that a [`Service::ReplyStream`] yields. +/// +/// On `std`, this is a tuple of the reply and the file descriptors to send with it. On `no_std`, +/// this is just the reply. +#[cfg(not(feature = "std"))] +pub type ReplyStreamItem = Reply; + /// Service trait for handling method calls. pub trait Service where @@ -32,7 +45,7 @@ where /// The type of the multi-reply stream. /// /// If the client asks for multiple replies, this stream will be used to send them. - type ReplyStream: Stream> + Unpin; + type ReplyStream: Stream> + Unpin; /// The type of the error reply. /// /// This should be a type that can serialize itself to the whole reply object, containing @@ -48,11 +61,29 @@ where &'ser mut self, method: &'ser Call>, conn: &mut Connection, + #[cfg(feature = "std")] fds: Vec, ) -> impl Future< - Output = MethodReply, Self::ReplyStream, Self::ReplyError<'ser>>, + Output = HandleResult, Self::ReplyStream, Self::ReplyError<'ser>>, >; } +/// The result of a [`Service::handle`] call. +/// +/// On `std`, this is a tuple of the method reply and the file descriptors to send with it. On +/// `no_std`, this is just the method reply. +#[cfg(feature = "std")] +pub type HandleResult = ( + MethodReply, + Vec, +); +/// The result of a [`Service::handle`] call. +/// +/// On `std`, this is a tuple of the method reply and the file descriptors to send with it. On +/// `no_std`, this is just the method reply. +#[cfg(not(feature = "std"))] +pub type HandleResult = + MethodReply; + /// A service method call reply. #[derive(Debug)] pub enum MethodReply { diff --git a/zlink-macros/src/service/codegen.rs b/zlink-macros/src/service/codegen.rs index 40bc3c62..51c07b3f 100644 --- a/zlink-macros/src/service/codegen.rs +++ b/zlink-macros/src/service/codegen.rs @@ -176,7 +176,7 @@ pub(super) fn generate_service_impl( quote! { ::std::boxed::Box< dyn ::futures_util::Stream< - Item = #crate_path::Reply<#reply_stream_params_name> + Item = #crate_path::service::ReplyStreamItem<#reply_stream_params_name> > + ::core::marker::Unpin > }, @@ -201,12 +201,18 @@ pub(super) fn generate_service_impl( } else { // No streaming methods - use empty stream. ( - quote! { ::futures_util::stream::Empty<#crate_path::Reply<()>> }, + quote! { ::futures_util::stream::Empty<#crate_path::service::ReplyStreamItem<()>> }, quote! { () }, quote! {}, ) }; + // Generate the FDs parameter for the handle method (std only). + #[cfg(feature = "std")] + let handle_fds_param = quote! { __zlink_fds: ::std::vec::Vec<::std::os::fd::OwnedFd>, }; + #[cfg(not(feature = "std"))] + let handle_fds_param = quote! {}; + // Generate the impl block. let service_impl = quote! { impl #generics #crate_path::Service<#socket_ty> for #self_ty @@ -222,7 +228,8 @@ pub(super) fn generate_service_impl( &'__zlink_ser mut self, __zlink_call: &'__zlink_ser #crate_path::Call>, __zlink_conn: &mut #crate_path::Connection<#socket_ty>, - ) -> #crate_path::service::MethodReply< + #handle_fds_param + ) -> #crate_path::service::HandleResult< Self::ReplyParams<'__zlink_ser>, Self::ReplyStream, Self::ReplyError<'__zlink_ser>, @@ -585,6 +592,48 @@ fn generate_reply_stream_enum( .collect(); // Generate poll_next match arms. + // On std, stream items are (Reply, Vec) tuples. + // On no_std, stream items are just Reply. + #[cfg(feature = "std")] + let poll_arms: Vec = streaming_methods + .iter() + .map(|method| { + let variant_name = format_ident!("{}", method.varlink_name); + let item_type = method.stream_item_type.as_ref().unwrap(); + let type_str = item_type.to_token_stream().to_string(); + let params_variant = stream_item_type_map + .get(&type_str) + .cloned() + .unwrap_or_else(|| format_ident!("__Unknown")); + + if method.return_fds { + // Method's stream yields (Reply, Vec). + quote! { + #projection_name::#variant_name { stream } => { + stream.poll_next(cx).map(|opt| opt.map(|(reply, fds)| { + let mapped_reply = reply.map(|params| { + #reply_stream_params_name::#params_variant(params) + }); + (mapped_reply, fds) + })) + } + } + } else { + // Method's stream yields Reply, wrap with empty FDs. + quote! { + #projection_name::#variant_name { stream } => { + stream.poll_next(cx).map(|opt| opt.map(|reply| { + let mapped_reply = reply.map(|params| { + #reply_stream_params_name::#params_variant(params) + }); + (mapped_reply, ::std::vec::Vec::new()) + })) + } + } + } + }) + .collect(); + #[cfg(not(feature = "std"))] let poll_arms: Vec = streaming_methods .iter() .map(|method| { @@ -616,7 +665,7 @@ fn generate_reply_stream_enum( } impl ::futures_util::Stream for #enum_name { - type Item = #crate_path::Reply<#reply_stream_params_name>; + type Item = #crate_path::service::ReplyStreamItem<#reply_stream_params_name>; fn poll_next( self: ::core::pin::Pin<&mut Self>, @@ -831,6 +880,46 @@ fn generate_interface_descriptions( quote! { #(#descriptions)* } } +/// Wrap a `MethodReply` expression into a `HandleResult` with no file descriptors. +/// +/// On std: `(method_reply, Vec::new())` +/// On no_std: `method_reply` +fn wrap_handle_result_no_fds(inner: TokenStream) -> TokenStream { + #[cfg(feature = "std")] + { + quote! { + { + let __method_reply = { #inner }; + (__method_reply, ::std::vec::Vec::new()) + } + } + } + #[cfg(not(feature = "std"))] + { + inner + } +} + +/// Wrap a `MethodReply` expression into a `HandleResult` with the given file descriptors. +/// +/// On std: `(method_reply, fds_expr)` +/// On no_std: `method_reply` +#[cfg(feature = "std")] +fn wrap_handle_result_with_fds(inner: TokenStream, fds_expr: TokenStream) -> TokenStream { + quote! { + { + let __method_reply = { #inner }; + (__method_reply, #fds_expr) + } + } +} + +/// On no_std, FDs are ignored. +#[cfg(not(feature = "std"))] +fn wrap_handle_result_with_fds(inner: TokenStream, _fds_expr: TokenStream) -> TokenStream { + inner +} + /// Generate the `handle` method body with match arms. fn generate_handle_body( methods_info: &[MethodInfo], @@ -903,11 +992,22 @@ fn generate_handle_body( }) .collect(); + // Set up bindings for FD params. + let fds_bindings: Vec = method + .params + .iter() + .filter(|p| p.is_fds) + .map(|p| { + let name = &p.name; + quote! { let #name = __zlink_fds; } + }) + .collect(); + // Set up bindings for regular params (clone from pattern match). let param_bindings: Vec = method .params .iter() - .filter(|p| !p.is_connection && !p.is_more) + .filter(|p| !p.is_connection && !p.is_more && !p.is_fds) .map(|p| { let name = &p.name; quote! { let #name = ::core::clone::Clone::clone(#name); } @@ -918,6 +1018,7 @@ fn generate_handle_body( { #(#conn_bindings)* #(#more_bindings)* + #(#fds_bindings)* #(#param_bindings)* async move #body }.await @@ -931,6 +1032,9 @@ fn generate_handle_body( if p.is_more { // `more` param comes from the call request. quote! { __zlink_call.more() } + } else if p.is_fds { + // FD param gets the incoming FDs. + quote! { __zlink_fds } } else { // Clone from the pattern match. let name = &p.name; @@ -943,6 +1047,7 @@ fn generate_handle_body( }; // Build the return expression based on method type. + // Each branch must produce a `HandleResult`, not a raw `MethodReply`. let return_expr = if method.is_streaming { // Streaming method - wrap the stream to convert items to the enum type. // The stream produces Reply items, we map to Reply. @@ -961,11 +1066,55 @@ fn generate_handle_body( // Use the method's varlink name for the enum variant. let method_variant_name = format_ident!("{}", method.varlink_name); - if *needs_stream_boxing { + let streaming_reply = if *needs_stream_boxing { // Use boxing when any streaming method uses `impl Trait`. - quote! { + // On std, stream items are (Reply, Vec) tuples. + #[cfg(feature = "std")] + let boxed_stream = if method.return_fds { + // Method's stream yields (Reply, Vec). + quote! { + let __stream = #method_call; + let __mapped = ::futures_util::StreamExt::map( + __stream, + |(__reply, __fds)| { + let __mapped_reply = __reply.map(|__params| { + #reply_stream_params_name::#stream_variant_name(__params) + }); + (__mapped_reply, __fds) + }, + ); + let __boxed: ::std::boxed::Box< + dyn ::futures_util::Stream< + Item = #crate_path::service::ReplyStreamItem< + #reply_stream_params_name + > + > + ::core::marker::Unpin + > = ::std::boxed::Box::new(__mapped); + #crate_path::service::MethodReply::Multi(__boxed) + } + } else { + // Method's stream yields Reply, wrap with empty FDs. + quote! { + let __stream = #method_call; + let __mapped = ::futures_util::StreamExt::map(__stream, |__reply| { + let __mapped_reply = __reply.map(|__params| { + #reply_stream_params_name::#stream_variant_name(__params) + }); + (__mapped_reply, ::std::vec::Vec::new()) + }); + let __boxed: ::std::boxed::Box< + dyn ::futures_util::Stream< + Item = #crate_path::service::ReplyStreamItem< + #reply_stream_params_name + > + > + ::core::marker::Unpin + > = ::std::boxed::Box::new(__mapped); + #crate_path::service::MethodReply::Multi(__boxed) + } + }; + #[cfg(not(feature = "std"))] + let boxed_stream = quote! { let __stream = #method_call; - // Map each Reply to Reply using the enum variant. let __mapped = ::futures_util::StreamExt::map(__stream, |__reply| { __reply.map(|__params| { #reply_stream_params_name::#stream_variant_name(__params) @@ -973,20 +1122,109 @@ fn generate_handle_body( }); let __boxed: ::std::boxed::Box< dyn ::futures_util::Stream< - Item = #crate_path::Reply<#reply_stream_params_name> + Item = #crate_path::service::ReplyStreamItem< + #reply_stream_params_name + > > + ::core::marker::Unpin > = ::std::boxed::Box::new(__mapped); #crate_path::service::MethodReply::Multi(__boxed) - } + }; + boxed_stream } else { // Use enum variant when all streaming methods return concrete types. - // The mapping is done in the enum's Stream impl. quote! { let __stream = #method_call; #crate_path::service::MethodReply::Multi( #reply_stream_name::#method_variant_name { stream: __stream } ) } + }; + wrap_handle_result_no_fds(streaming_reply) + } else if method.return_fds && method.returns_result { + // return_fds + (Result, Vec). + // FDs are available on both Ok and Err arms. + let error_variant = method.error_type.as_ref().map(|err_ty| { + let type_str = err_ty.to_token_stream().to_string(); + let variant_idx = error_type_map.get(&type_str).copied().unwrap_or(0); + format_ident!("__{}Variant{}", reply_error_name, variant_idx) + }); + let error_convert = if let Some(ref err_variant) = error_variant { + quote! { #reply_error_name::#err_variant(__err) } + } else { + quote! { ::core::convert::From::from(__err) } + }; + let err_reply = quote! { + #crate_path::service::MethodReply::Error(#error_convert) + }; + let err_arm = wrap_handle_result_with_fds(err_reply, quote! { __out_fds }); + + if let Some(ref return_type) = method.return_type { + let type_str = return_type.to_token_stream().to_string(); + let variant_idx = type_to_variant.get(&type_str).copied().unwrap_or(0); + let reply_variant_name = + format_ident!("__{}Variant{}", method_call_name, variant_idx); + let ok_reply = quote! { + #crate_path::service::MethodReply::Single(Some( + #reply_params_name::#reply_variant_name(__ok) + )) + }; + let ok_arm = wrap_handle_result_with_fds(ok_reply, quote! { __out_fds }); + quote! { + let (__result, __out_fds) = #method_call; + match __result { + ::core::result::Result::Ok(__ok) => { + #ok_arm + } + ::core::result::Result::Err(__err) => { + #err_arm + } + } + } + } else { + // (Result<(), E>, Vec). + let ok_reply = quote! { + #crate_path::service::MethodReply::Single(None) + }; + let ok_arm = wrap_handle_result_with_fds(ok_reply, quote! { __out_fds }); + quote! { + let (__result, __out_fds) = #method_call; + match __result { + ::core::result::Result::Ok(()) => { + #ok_arm + } + ::core::result::Result::Err(__err) => { + #err_arm + } + } + } + } + } else if method.return_fds { + // return_fds without Result: (T, Vec). + if let Some(ref return_type) = method.return_type { + let type_str = return_type.to_token_stream().to_string(); + let variant_idx = type_to_variant.get(&type_str).copied().unwrap_or(0); + let reply_variant_name = + format_ident!("__{}Variant{}", method_call_name, variant_idx); + let ok_reply = quote! { + #crate_path::service::MethodReply::Single(Some( + #reply_params_name::#reply_variant_name(__ok) + )) + }; + let ok_arm = wrap_handle_result_with_fds(ok_reply, quote! { __out_fds }); + quote! { + let (__ok, __out_fds) = #method_call; + #ok_arm + } + } else { + // ((), Vec). + let ok_reply = quote! { + #crate_path::service::MethodReply::Single(None) + }; + let ok_arm = wrap_handle_result_with_fds(ok_reply, quote! { __out_fds }); + quote! { + let ((), __out_fds) = #method_call; + #ok_arm + } } } else if method.returns_result { // Method returns Result. Get the error variant for this method's error type. @@ -1007,7 +1245,7 @@ fn generate_handle_body( } else { quote! { ::core::convert::From::from(__err) } }; - quote! { + wrap_handle_result_no_fds(quote! { match #method_call { ::core::result::Result::Ok(__ok) => { #crate_path::service::MethodReply::Single(Some( @@ -1018,7 +1256,7 @@ fn generate_handle_body( #crate_path::service::MethodReply::Error(#error_convert) } } - } + }) } else { // Result<(), E>. let error_convert = if let Some(err_variant) = error_variant { @@ -1026,7 +1264,7 @@ fn generate_handle_body( } else { quote! { ::core::convert::From::from(__err) } }; - quote! { + wrap_handle_result_no_fds(quote! { match #method_call { ::core::result::Result::Ok(()) => { #crate_path::service::MethodReply::Single(None) @@ -1035,25 +1273,25 @@ fn generate_handle_body( #crate_path::service::MethodReply::Error(#error_convert) } } - } + }) } } else if let Some(ref return_type) = method.return_type { // Method returns T directly (not a Result). let type_str = return_type.to_token_stream().to_string(); let variant_idx = type_to_variant.get(&type_str).copied().unwrap_or(0); let reply_variant_name = format_ident!("__{}Variant{}", method_call_name, variant_idx); - quote! { + wrap_handle_result_no_fds(quote! { let __result = #method_call; #crate_path::service::MethodReply::Single(Some( #reply_params_name::#reply_variant_name(__result) )) - } + }) } else { // Method has no return type. - quote! { + wrap_handle_result_no_fds(quote! { let _ = #method_call; #crate_path::service::MethodReply::Single(None) - } + }) }; user_match_arms.push(quote! { @@ -1079,15 +1317,18 @@ fn generate_handle_body( ); // Add a catch-all arm for unknown methods (returns MethodNotFound error). + let unknown_method_reply = wrap_handle_result_no_fds(quote! { + #crate_path::service::MethodReply::Error( + #reply_error_name::#varlink_error_variant( + #crate_path::varlink_service::Error::MethodNotFound { + method: ::std::borrow::Cow::Borrowed("unknown"), + } + ) + ) + }); user_match_arms.push(quote! { #user_methods_name::#unknown_variant => { - #crate_path::service::MethodReply::Error( - #reply_error_name::#varlink_error_variant( - #crate_path::varlink_service::Error::MethodNotFound { - method: ::std::borrow::Cow::Borrowed("unknown"), - } - ) - ) + #unknown_method_reply } }); @@ -1109,14 +1350,18 @@ fn generate_handle_body( type_name.to_uppercase(), interface.replace('.', "_").to_uppercase() ); + let desc_reply = wrap_handle_result_no_fds(quote! { + #crate_path::service::MethodReply::Single(Some( + #reply_params_name::#varlink_reply_variant( + #crate_path::varlink_service::Reply::InterfaceDescription(desc) + ) + )) + }); quote! { #interface => { - let desc = #crate_path::varlink_service::InterfaceDescription::from(#const_name); - #crate_path::service::MethodReply::Single(Some( - #reply_params_name::#varlink_reply_variant( - #crate_path::varlink_service::Reply::InterfaceDescription(desc) - ) - )) + let desc = + #crate_path::varlink_service::InterfaceDescription::from(#const_name); + #desc_reply } } }) @@ -1133,6 +1378,29 @@ fn generate_handle_body( let url = service_attrs.url.as_deref().unwrap_or(""); // Generate the varlink service methods match. + let get_info_reply = wrap_handle_result_no_fds(quote! { + #crate_path::service::MethodReply::Single(Some( + #reply_params_name::#varlink_reply_variant( + #crate_path::varlink_service::Reply::Info(info) + ) + )) + }); + let varlink_desc_reply = wrap_handle_result_no_fds(quote! { + #crate_path::service::MethodReply::Single(Some( + #reply_params_name::#varlink_reply_variant( + #crate_path::varlink_service::Reply::InterfaceDescription(desc) + ) + )) + }); + let interface_not_found_reply = wrap_handle_result_no_fds(quote! { + #crate_path::service::MethodReply::Error( + #reply_error_name::#varlink_error_variant( + #crate_path::varlink_service::Error::InterfaceNotFound { + interface: ::std::borrow::Cow::Borrowed(interface), + } + ) + ) + }); let varlink_service_match = quote! { #method_call_name::__VarlinkService(__varlink_method) => { match __varlink_method { @@ -1147,33 +1415,20 @@ fn generate_handle_body( #crate_path::varlink_service::INTERFACE_NAME, ], ); - #crate_path::service::MethodReply::Single(Some( - #reply_params_name::#varlink_reply_variant( - #crate_path::varlink_service::Reply::Info(info) - ) - )) + #get_info_reply } #crate_path::varlink_service::Method::GetInterfaceDescription { interface } => { match *interface { #(#interface_match_arms)* #crate_path::varlink_service::INTERFACE_NAME => { - let desc = #crate_path::varlink_service::InterfaceDescription::from( - #crate_path::varlink_service::DESCRIPTION - ); - #crate_path::service::MethodReply::Single(Some( - #reply_params_name::#varlink_reply_variant( - #crate_path::varlink_service::Reply::InterfaceDescription(desc) - ) - )) + let desc = + #crate_path::varlink_service::InterfaceDescription::from( + #crate_path::varlink_service::DESCRIPTION + ); + #varlink_desc_reply } _ => { - #crate_path::service::MethodReply::Error( - #reply_error_name::#varlink_error_variant( - #crate_path::varlink_service::Error::InterfaceNotFound { - interface: ::std::borrow::Cow::Borrowed(interface), - } - ) - ) + #interface_not_found_reply } } } diff --git a/zlink-macros/src/service/method.rs b/zlink-macros/src/service/method.rs index 4ca0ee4d..4193d4ba 100644 --- a/zlink-macros/src/service/method.rs +++ b/zlink-macros/src/service/method.rs @@ -34,6 +34,8 @@ pub(super) struct MethodInfo { pub stream_return_type: Option, /// Whether the streaming method returns `impl Trait` (requires boxing). pub stream_uses_impl_trait: bool, + /// Whether this method returns file descriptors (`#[zlink(return_fds)]`). + pub return_fds: bool, } impl MethodInfo { @@ -75,6 +77,7 @@ impl MethodInfo { let param_attrs = extract_param_attrs(&pat_type.attrs); param_info.serialized_name = param_attrs.rename; param_info.is_connection = param_attrs.is_connection; + param_info.is_fds = param_attrs.is_fds; } // For streaming methods, the first param must be `more: bool`. if is_streaming && idx == 0 { @@ -100,6 +103,16 @@ impl MethodInfo { } } + // Validate FD attributes. + let return_fds = method_attrs.return_fds; + let fds_params: Vec<_> = params.iter().filter(|p| p.is_fds).collect(); + if fds_params.len() > 1 { + return Err(Error::new_spanned( + &method.sig, + "at most one `#[zlink(fds)]` parameter is allowed per method", + )); + } + // Extract return type and check if it's a Result or Stream. let ( return_type, @@ -108,7 +121,57 @@ impl MethodInfo { stream_item_type, stream_return_type, stream_uses_impl_trait, - ) = if is_streaming { + ) = if is_streaming && return_fds { + // For streaming methods with FD passing, the stream yields (Reply, Vec). + match &method.sig.output { + ReturnType::Default => { + return Err(Error::new_spanned( + &method.sig, + "streaming methods with return_fds must return \ + a Stream, Vec)>", + )) + } + ReturnType::Type(_, ty) => { + let stream_item = extract_stream_item_type(ty).ok_or_else(|| { + Error::new_spanned( + ty, + "streaming methods with return_fds must return \ + a Stream, Vec)> \ + (could not extract Stream's Item type)", + ) + })?; + // Extract Reply from (Reply, Vec). + let reply_type = + extract_first_tuple_element(&stream_item).ok_or_else(|| { + Error::new_spanned( + ty, + "streaming methods with return_fds must return \ + a Stream, Vec)> \ + (stream item must be a tuple)", + ) + })?; + // Extract T from Reply. + let inner_type = extract_reply_inner_type(&reply_type).ok_or_else(|| { + Error::new_spanned( + ty, + "streaming methods with return_fds must return \ + a Stream, Vec)> \ + (first tuple element must be Reply)", + ) + })?; + // Check if return type uses `impl Trait`. + let uses_impl_trait = matches!(**ty, Type::ImplTrait(_)); + ( + None, + false, + None, + Some(inner_type), + Some((**ty).clone()), + uses_impl_trait, + ) + } + } + } else if is_streaming { // For streaming methods, extract the Stream's Item type. // Streaming methods can return either: // - `impl Stream>` (will use boxing) @@ -148,6 +211,42 @@ impl MethodInfo { ) } } + } else if return_fds { + // For return_fds methods, the return type is a tuple whose second element + // is `Vec`. The first element is either: + // - `Result` → `(Result, Vec)` — extract T and E + // - `T` → `(T, Vec)` — extract T, no error type + match &method.sig.output { + ReturnType::Default => { + return Err(Error::new_spanned( + &method.sig, + "`return_fds` methods must have a return type", + )) + } + ReturnType::Type(_, ty) => { + // Extract the first element of the tuple. + let first = extract_first_tuple_element(ty).ok_or_else(|| { + Error::new_spanned( + ty, + "`return_fds` methods must return \ + `(T, Vec)` or `(Result, Vec)`", + ) + })?; + + if let Some((inner_ty, err_ty)) = extract_result_types(&first) { + // (Result, Vec). + (inner_ty, true, Some(err_ty), None, None, false) + } else { + // (T, Vec). + let data_ty = if is_unit_type(&first) { + None + } else { + Some(first) + }; + (data_ty, false, None, None, None, false) + } + } + } } else { // For non-streaming methods, extract Result types as before. match &method.sig.output { @@ -179,6 +278,7 @@ impl MethodInfo { stream_item_type, stream_return_type, stream_uses_impl_trait, + return_fds, }) } @@ -194,11 +294,11 @@ impl MethodInfo { self.params.iter().any(|p| p.is_connection) } - /// Get parameters that are serialized (excludes connection and more parameters). + /// Get parameters that are serialized (excludes connection, more, and fds parameters). pub(super) fn serialized_params(&self) -> impl Iterator { self.params .iter() - .filter(|p| !p.is_connection && !p.is_more) + .filter(|p| !p.is_connection && !p.is_more && !p.is_fds) } } @@ -211,6 +311,8 @@ struct MethodAttrs { rename: Option, /// Whether this method returns a stream of replies. is_streaming: bool, + /// Whether this method returns file descriptors. + return_fds: bool, } impl MethodAttrs { @@ -276,6 +378,15 @@ impl MethodAttrs { } result.is_streaming = true; } + Meta::Path(path) if path.is_ident("return_fds") => { + if result.return_fds { + return Err(Error::new_spanned( + &meta, + "duplicate `return_fds` attribute", + )); + } + result.return_fds = true; + } _ => { return Err(Error::new_spanned(&meta, "unknown zlink attribute")); } @@ -299,6 +410,8 @@ struct ParamAttrs { rename: Option, /// Whether this parameter should receive the connection. is_connection: bool, + /// Whether this parameter receives file descriptors. + is_fds: bool, } /// Extract zlink attributes from parameter attributes. @@ -332,6 +445,9 @@ fn extract_param_attrs(attrs: &[Attribute]) -> ParamAttrs { Meta::Path(path) if path.is_ident("connection") => { result.is_connection = true; } + Meta::Path(path) if path.is_ident("fds") => { + result.is_fds = true; + } _ => {} } } @@ -502,3 +618,20 @@ fn is_bool_type(ty: &Type) -> bool { }; type_path.path.is_ident("bool") } + +/// Check if a type is the unit type `()`. +fn is_unit_type(ty: &Type) -> bool { + let Type::Tuple(tuple) = ty else { + return false; + }; + tuple.elems.is_empty() +} + +/// Extract the first element type from a tuple type `(T, ...)`. +/// Returns `None` if the type is not a tuple with at least one element. +fn extract_first_tuple_element(ty: &Type) -> Option { + let Type::Tuple(tuple) = ty else { + return None; + }; + tuple.elems.first().cloned() +} diff --git a/zlink-macros/src/service/types.rs b/zlink-macros/src/service/types.rs index 0ae9e2d4..5dbef04a 100644 --- a/zlink-macros/src/service/types.rs +++ b/zlink-macros/src/service/types.rs @@ -13,6 +13,8 @@ pub(super) struct ParamInfo { pub serialized_name: Option, /// Whether this parameter is marked with `#[zlink(connection)]`. pub is_connection: bool, + /// Whether this parameter is marked with `#[zlink(fds)]`. + pub is_fds: bool, /// Whether this is the `more` parameter for streaming methods. pub is_more: bool, } @@ -32,6 +34,7 @@ impl ParamInfo { ty: (*pat_type.ty).clone(), serialized_name: None, is_connection: false, + is_fds: false, is_more: false, }) } diff --git a/zlink/tests/service-macro.rs b/zlink/tests/service-macro.rs index 68b8a702..24892e86 100644 --- a/zlink/tests/service-macro.rs +++ b/zlink/tests/service-macro.rs @@ -767,3 +767,411 @@ trait StreamingProxy { &mut self, ) -> zlink::Result>>>; } + +// ============================================================================ +// Test file descriptor passing with service macro +// ============================================================================ + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn service_macro_fd_passing() -> Result<(), Box> { + let socket_path = "/tmp/zlink-service-macro-fd-test.sock"; + if let Err(e) = tokio::fs::remove_file(socket_path).await { + if e.kind() != std::io::ErrorKind::NotFound { + return Err(e.into()); + } + } + + let listener = bind(socket_path).unwrap(); + let service = FdService; + let server = Server::new(listener, service); + tokio::select! { + res = server.run() => res?, + res = run_fd_client(socket_path) => res?, + } + + Ok(()) +} + +async fn run_fd_client(socket_path: &str) -> Result<(), Box> { + use std::{ + io::{Read, Write}, + os::unix::net::UnixStream, + }; + + let mut conn = connect(socket_path).await?; + + // Send multiple FDs and read from a specific one by index. + let (r0, mut w0) = UnixStream::pair()?; + let (r1, mut w1) = UnixStream::pair()?; + let (r2, mut w2) = UnixStream::pair()?; + w0.write_all(b"data-zero")?; + w1.write_all(b"data-one")?; + w2.write_all(b"data-two")?; + drop((w0, w1, w2)); + let fds = vec![r0.into(), r1.into(), r2.into()]; + // Read from index 1. + let data = conn.read_fd(1, fds).await?.unwrap(); + assert_eq!(data, "data-one"); + + // Invalid index returns an error. + let (r, mut w) = UnixStream::pair()?; + w.write_all(b"some data")?; + drop(w); + let result = conn.read_fd(5, vec![r.into()]).await?; + assert!(matches!(result, Err(FdError::InvalidIndex { index: 5 }))); + + // Receive FDs from the service. Each handle has a name and fd_index referencing the FD vector. + let names = vec!["config.txt".into(), "data.bin".into(), "log.txt".into()]; + let (result, fds) = conn.open_fds(names).await?; + let handles = result.unwrap(); + assert_eq!(handles.len(), 3); + assert_eq!(fds.len(), 3); + // Verify each handle's name and that the FD at fd_index contains the name as content. + for handle in &handles { + let fd = &fds[handle.fd_index as usize]; + let cloned_fd = fd.try_clone()?; + let mut stream = UnixStream::from(cloned_fd); + let mut buf = String::new(); + stream.read_to_string(&mut buf)?; + assert_eq!(buf, handle.name); + } + + // Receive zero FDs from the service. + let (result, fds) = conn.open_fds(Vec::new()).await?; + let handles = result.unwrap(); + assert!(handles.is_empty()); + assert!(fds.is_empty()); + + // Receive an FD on success path and verify the handle's index references the correct FD. + let (result, fds) = conn.try_open_fd("success.txt".into(), false).await?; + let handle = result.unwrap(); + assert_eq!(handle.name, "success.txt"); + assert_eq!(handle.fd_index, 0); + assert_eq!(fds.len(), 1); + let mut stream = UnixStream::from(fds.into_iter().next().unwrap()); + let mut buf = String::new(); + stream.read_to_string(&mut buf)?; + assert_eq!(buf, "success.txt"); + + // Receive an FD on error path and verify the diagnostic content. + let (result, fds) = conn.try_open_fd("missing.txt".into(), true).await?; + let err = result.unwrap_err(); + assert!(matches!(err, FdError::NotFound { name } if name == "missing.txt")); + assert_eq!(fds.len(), 1); + let mut stream = UnixStream::from(fds.into_iter().next().unwrap()); + let mut buf = String::new(); + stream.read_to_string(&mut buf)?; + assert_eq!(buf, "error-diagnostic"); + + Ok(()) +} + +// Response type for FD operations. The `fd_index` field references a position in the FD vector. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct FdHandle { + name: String, + fd_index: u32, +} + +// Error type for FD operations. +#[derive(Debug, Clone, PartialEq, zlink::ReplyError, introspect::ReplyError)] +#[zlink(interface = "org.example.fd")] +enum FdError { + InvalidIndex { index: u32 }, + NotFound { name: String }, +} + +// A service that tests file descriptor passing. +struct FdService; + +#[zlink::service(interface = "org.example.fd")] +impl FdService { + /// Receive FDs and read from the one at the given index. + async fn read_fd( + &self, + fd_index: u32, + #[zlink(fds)] fds: Vec, + ) -> Result { + use std::{io::Read, os::unix::net::UnixStream}; + + let Some(fd) = fds.into_iter().nth(fd_index as usize) else { + return Err(FdError::InvalidIndex { index: fd_index }); + }; + let mut stream = UnixStream::from(fd); + let mut buf = String::new(); + stream.read_to_string(&mut buf).unwrap(); + Ok(buf) + } + + /// Open a list of named FDs and return handles with their indexes. + #[zlink(return_fds)] + async fn open_fds(&self, names: Vec) -> (Vec, Vec) { + use std::{io::Write, os::unix::net::UnixStream}; + + let mut handles = Vec::new(); + let mut fds = Vec::new(); + for (i, name) in names.into_iter().enumerate() { + let (r, mut w) = UnixStream::pair().unwrap(); + // Write the name as the FD content for verification. + w.write_all(name.as_bytes()).unwrap(); + drop(w); + handles.push(FdHandle { + name, + fd_index: i as u32, + }); + fds.push(r.into()); + } + (handles, fds) + } + + /// Try to open an FD. On success, return the handle with its index. On error, return the + /// error alongside a diagnostic FD. + #[zlink(return_fds)] + async fn try_open_fd( + &self, + name: String, + should_fail: bool, + ) -> (Result, Vec) { + use std::{io::Write, os::unix::net::UnixStream}; + + let (r, mut w) = UnixStream::pair().unwrap(); + if should_fail { + w.write_all(b"error-diagnostic").unwrap(); + drop(w); + ( + Err(FdError::NotFound { name }), + vec![std::os::fd::OwnedFd::from(r)], + ) + } else { + w.write_all(name.as_bytes()).unwrap(); + drop(w); + ( + Ok(FdHandle { name, fd_index: 0 }), + vec![std::os::fd::OwnedFd::from(r)], + ) + } + } +} + +// Proxy for FD service. +#[zlink::proxy("org.example.fd")] +trait FdProxy { + async fn read_fd( + &mut self, + fd_index: u32, + #[zlink(fds)] fds: Vec, + ) -> zlink::Result>; + + #[zlink(return_fds)] + async fn open_fds( + &mut self, + names: Vec, + ) -> zlink::Result<(Result, FdError>, Vec)>; + + #[zlink(return_fds)] + async fn try_open_fd( + &mut self, + name: String, + should_fail: bool, + ) -> zlink::Result<(Result, Vec)>; +} + +// ============================================================================ +// Test streaming service methods with FD passing (#[zlink(more, return_fds)]) +// ============================================================================ + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn service_macro_streaming_with_fds() -> Result<(), Box> { + let socket_path = "/tmp/zlink-service-macro-streaming-fd-test.sock"; + if let Err(e) = tokio::fs::remove_file(socket_path).await { + if e.kind() != std::io::ErrorKind::NotFound { + return Err(e.into()); + } + } + + let listener = bind(socket_path).unwrap(); + let service = StreamingFdService; + let server = Server::new(listener, service); + tokio::select! { + res = server.run() => res?, + res = run_streaming_fd_client(socket_path) => res?, + } + + Ok(()) +} + +async fn run_streaming_fd_client(socket_path: &str) -> Result<(), Box> { + use futures_util::StreamExt; + use std::{ + io::{Read, Write}, + os::unix::net::UnixStream, + }; + + let mut conn = connect(socket_path).await?; + + // ========================================================================= + // Test 1: Stream output FDs (return_fds + more) + // ========================================================================= + { + let names = vec![ + "first".to_string(), + "second".to_string(), + "third".to_string(), + ]; + let mut stream = std::pin::pin!(conn.stream_fds(names).await?); + + // Collect all stream items. + let mut handles = Vec::new(); + let mut all_fds = Vec::new(); + while let Some(result) = stream.next().await { + let (result, fds) = result?; + let handle = result.unwrap(); + handles.push(handle); + all_fds.extend(fds); + } + + // Should have received 3 handles with 3 FDs. + assert_eq!(handles.len(), 3); + assert_eq!(all_fds.len(), 3); + + // Verify each handle's FD contains the expected content. + for (i, handle) in handles.iter().enumerate() { + assert_eq!(handle.fd_index, i as u32); + let fd = all_fds[handle.fd_index as usize].try_clone()?; + let mut stream = UnixStream::from(fd); + let mut buf = String::new(); + stream.read_to_string(&mut buf)?; + assert_eq!(buf, handle.name); + } + } + + // ========================================================================= + // Test 2: Stream input FDs (fds + more) + // ========================================================================= + { + // Create 3 FDs with known content. + let (r0, mut w0) = UnixStream::pair()?; + let (r1, mut w1) = UnixStream::pair()?; + let (r2, mut w2) = UnixStream::pair()?; + w0.write_all(b"content-zero")?; + w1.write_all(b"content-one")?; + w2.write_all(b"content-two")?; + drop((w0, w1, w2)); + + let fds = vec![r0.into(), r1.into(), r2.into()]; + let mut stream = std::pin::pin!(conn.read_fds_streaming(fds).await?); + + // Collect all stream items. + let mut results = Vec::new(); + while let Some(result) = stream.next().await { + let read_result = result?.unwrap(); + results.push(read_result); + } + + // Should have received 3 results. + assert_eq!(results.len(), 3); + + // Verify each result has the expected content. + assert_eq!(results[0].fd_index, 0); + assert_eq!(results[0].content, "content-zero"); + assert_eq!(results[1].fd_index, 1); + assert_eq!(results[1].content, "content-one"); + assert_eq!(results[2].fd_index, 2); + assert_eq!(results[2].content, "content-two"); + } + + Ok(()) +} + +/// Response for streaming FD read operations. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct FdReadResult { + fd_index: u32, + content: String, +} + +/// A service that streams file descriptors. +struct StreamingFdService; + +#[zlink::service(interface = "org.example.streaming_fd")] +impl StreamingFdService { + /// Stream FDs with handles, one per name. Each stream item contains a handle and the FD. + #[zlink(more, return_fds)] + async fn stream_fds( + &self, + more: bool, + names: Vec, + ) -> impl futures_util::Stream, Vec)> + Unpin + { + use std::{io::Write, os::unix::net::UnixStream}; + + // If more=false, only return the first item. + let names: Vec = if more { + names + } else { + names.into_iter().take(1).collect() + }; + + let n = names.len(); + futures_util::stream::iter(names.into_iter().enumerate().map(move |(i, name)| { + let (r, mut w) = UnixStream::pair().unwrap(); + w.write_all(name.as_bytes()).unwrap(); + drop(w); + let handle = FdHandle { + name, + fd_index: i as u32, + }; + let reply = zlink::Reply::new(Some(handle)).set_continues(Some(i < n - 1)); + (reply, vec![r.into()]) + })) + } + + /// Receive FDs and stream back the content read from each one. + #[zlink(more)] + async fn read_fds_streaming( + &self, + more: bool, + #[zlink(fds)] fds: Vec, + ) -> impl futures_util::Stream> + Unpin { + use std::{io::Read, os::unix::net::UnixStream}; + + // If more=false, only return the first result. + let fds: Vec = if more { + fds + } else { + fds.into_iter().take(1).collect() + }; + + let n = fds.len(); + futures_util::stream::iter(fds.into_iter().enumerate().map(move |(i, fd)| { + let mut stream = UnixStream::from(fd); + let mut content = String::new(); + stream.read_to_string(&mut content).unwrap(); + let result = FdReadResult { + fd_index: i as u32, + content, + }; + zlink::Reply::new(Some(result)).set_continues(Some(i < n - 1)) + })) + } +} + +/// Proxy for streaming FD service. +#[zlink::proxy("org.example.streaming_fd")] +trait StreamingFdProxy { + #[zlink(more, return_fds)] + async fn stream_fds( + &mut self, + names: Vec, + ) -> zlink::Result< + impl futures_util::Stream< + Item = zlink::Result<(Result, Vec)>, + >, + >; + + #[zlink(more)] + async fn read_fds_streaming( + &mut self, + #[zlink(fds)] fds: Vec, + ) -> zlink::Result>>>; +}