Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ use serde::{Deserialize, Serialize};
use tokio::{select, sync::oneshot, fs::remove_file};
use zlink::{
proxy,
service::{MethodReply, Service},
service::{self, MethodReply, Service},
connection::{Connection, Socket},
unix, Call, ReplyError, Server,
};
Expand Down Expand Up @@ -196,7 +196,7 @@ where
where
Self: 'ser;
type ReplyStreamParams = ();
type ReplyStream = futures_util::stream::Empty<zlink::Reply<()>>;
type ReplyStream = futures_util::stream::Empty<zlink::service::ReplyStreamItem<()>>;
type ReplyError<'ser> = CalculatorError<'ser>
where
Self: 'ser;
Expand All @@ -205,15 +205,25 @@ where
&'service mut self,
call: &'service Call<Self::MethodCall<'_>>,
conn: &mut Connection<Sock>,
) -> MethodReply<Self::ReplyParams<'service>, Self::ReplyStream, Self::ReplyError<'service>> {
match call.method() {
fds: Vec<std::os::fd::OwnedFd>,
) -> service::HandleResult<
Self::ReplyParams<'service>,
Self::ReplyStream,
Self::ReplyError<'service>,
> {
let _ = (conn, fds);
let reply = match call.method() {
CalculatorMethod::Add { a, b } => {
self.operations.push(format!("add({}, {})", a, b));
MethodReply::Single(Some(CalculatorReply::Result(CalculationResult { result: a + b })))
MethodReply::Single(Some(CalculatorReply::Result(
CalculationResult { result: a + b },
)))
}
CalculatorMethod::Multiply { x, y } => {
self.operations.push(format!("multiply({}, {})", x, y));
MethodReply::Single(Some(CalculatorReply::Result(CalculationResult { result: x * y })))
MethodReply::Single(Some(CalculatorReply::Result(
CalculationResult { result: x * y },
)))
}
CalculatorMethod::Divide { dividend, divisor } => {
if *divisor == 0.0 {
Expand All @@ -226,20 +236,25 @@ where
reason: "must be within range",
})
} else {
self.operations.push(format!("divide({}, {})", dividend, divisor));
MethodReply::Single(Some(CalculatorReply::Result(CalculationResult {
result: dividend / divisor,
})))
self.operations
.push(format!("divide({}, {})", dividend, divisor));
MethodReply::Single(Some(CalculatorReply::Result(
CalculationResult {
result: dividend / divisor,
},
)))
}
}
CalculatorMethod::GetStats => {
let ops: Vec<&str> = self.operations.iter().map(|s| s.as_str()).collect();
let ops: Vec<&str> =
self.operations.iter().map(|s| s.as_str()).collect();
MethodReply::Single(Some(CalculatorReply::Stats(Statistics {
count: self.operations.len() as u64,
operations: ops,
})))
}
}
};
(reply, Vec::new())
}
}

Expand Down
2 changes: 1 addition & 1 deletion zlink-core/src/connection/chain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,8 @@ mod tests {
Ok(())
}

#[cfg(feature = "std")]
#[tokio::test]
#[cfg(feature = "std")]
async fn chain_from_iter_with_fds() -> crate::Result<()> {
use crate::{
connection::socket::{ReadHalf, WriteHalf},
Expand Down
2 changes: 1 addition & 1 deletion zlink-core/src/connection/tests/connection_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

use crate::{test_utils::mock_socket::MockSocket, Connection};

#[cfg(feature = "std")]
#[tokio::test]
#[cfg(feature = "std")]
async fn peer_credentials_mock_socket() {
// Get the expected credentials of the current process.
let expected_uid = rustix::process::getuid();
Expand Down
39 changes: 30 additions & 9 deletions zlink-core/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ where
// 2. Read method calls from the existing connections and handle them.
(idx, result) = read_select_all.fuse() => {
#[cfg(feature = "std")]
let call = result.map(|(call, _fds)| call);
let (call, fds) = match result {
Ok((call, fds)) => (Ok(call), fds),
Err(e) => (Err(e), alloc::vec![]),
};
#[cfg(not(feature = "std"))]
let call = result;
last_method_call_winner = Some(idx);
Expand All @@ -118,7 +121,13 @@ where
let mut remove = true;
match call {
Ok(call) => {
match self.handle_call(call, &mut connections[idx]).await {
#[cfg(feature = "std")]
let result =
self.handle_call(call, &mut connections[idx], fds).await;
#[cfg(not(feature = "std"))]
let result =
self.handle_call(call, &mut connections[idx]).await;
match result {
Ok(None) => remove = false,
Ok(Some(s)) => stream = Some(s),
Err(e) => warn!("Error writing to connection: {:?}", e),
Expand All @@ -137,15 +146,20 @@ where
}
// 3. Read replies from the reply streams and send them off.
reply = reply_stream_select_all.fuse() => {
let (idx, reply) = reply;
let (idx, item) = reply;
last_reply_stream_winner = Some(idx);
let id = reply_streams[idx].conn.id();

match reply {
Some(reply) => {
match item {
Some(item) => {
#[cfg(feature = "std")]
let (reply, fds) = item;
#[cfg(not(feature = "std"))]
let reply = item;

#[cfg(feature = "std")]
let send_result =
reply_streams[idx].conn.send_reply(&reply, alloc::vec![]).await;
reply_streams[idx].conn.send_reply(&reply, fds).await;
#[cfg(not(feature = "std"))]
let send_result = reply_streams[idx].conn.send_reply(&reply).await;
if let Err(e) = send_result {
Expand All @@ -168,20 +182,27 @@ where
&mut self,
call: Call<Service::MethodCall<'_>>,
conn: &mut Connection<Listener::Socket>,
#[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
) -> crate::Result<Option<Service::ReplyStream>> {
let mut stream = None;
match self.service.handle(&call, conn).await {

#[cfg(feature = "std")]
let (reply, reply_fds) = self.service.handle(&call, conn, fds).await;
#[cfg(not(feature = "std"))]
let reply = self.service.handle(&call, conn).await;

match reply {
// Don't send replies or errors for oneway calls.
MethodReply::Single(_) | MethodReply::Error(_) if call.oneway() => (),
MethodReply::Single(params) => {
let reply = Reply::new(params).set_continues(Some(false));
#[cfg(feature = "std")]
conn.send_reply(&reply, alloc::vec![]).await?;
conn.send_reply(&reply, reply_fds).await?;
#[cfg(not(feature = "std"))]
conn.send_reply(&reply).await?;
}
#[cfg(feature = "std")]
MethodReply::Error(err) => conn.send_error(&err, alloc::vec![]).await?,
MethodReply::Error(err) => conn.send_error(&err, reply_fds).await?,
#[cfg(not(feature = "std"))]
MethodReply::Error(err) => conn.send_error(&err).await?,
MethodReply::Multi(s) => {
Expand Down
35 changes: 33 additions & 2 deletions zlink-core/src/server/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ use serde::{Deserialize, Serialize};

use crate::{connection::Socket, Call, Connection, Reply};

/// The item type that a [`Service::ReplyStream`] yields.
///
/// On `std`, this is a tuple of the reply and the file descriptors to send with it. On `no_std`,
/// this is just the reply.
#[cfg(feature = "std")]
pub type ReplyStreamItem<Params> = (Reply<Params>, Vec<std::os::fd::OwnedFd>);
/// The item type that a [`Service::ReplyStream`] yields.
///
/// On `std`, this is a tuple of the reply and the file descriptors to send with it. On `no_std`,
/// this is just the reply.
#[cfg(not(feature = "std"))]
pub type ReplyStreamItem<Params> = Reply<Params>;

/// Service trait for handling method calls.
pub trait Service<Sock>
where
Expand All @@ -32,7 +45,7 @@ where
/// The type of the multi-reply stream.
///
/// If the client asks for multiple replies, this stream will be used to send them.
type ReplyStream: Stream<Item = Reply<Self::ReplyStreamParams>> + Unpin;
type ReplyStream: Stream<Item = ReplyStreamItem<Self::ReplyStreamParams>> + Unpin;
/// The type of the error reply.
///
/// This should be a type that can serialize itself to the whole reply object, containing
Expand All @@ -48,11 +61,29 @@ where
&'ser mut self,
method: &'ser Call<Self::MethodCall<'_>>,
conn: &mut Connection<Sock>,
#[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
) -> impl Future<
Output = MethodReply<Self::ReplyParams<'ser>, Self::ReplyStream, Self::ReplyError<'ser>>,
Output = HandleResult<Self::ReplyParams<'ser>, Self::ReplyStream, Self::ReplyError<'ser>>,
>;
}

/// The result of a [`Service::handle`] call.
///
/// On `std`, this is a tuple of the method reply and the file descriptors to send with it. On
/// `no_std`, this is just the method reply.
#[cfg(feature = "std")]
pub type HandleResult<Params, ReplyStream, ReplyError> = (
MethodReply<Params, ReplyStream, ReplyError>,
Vec<std::os::fd::OwnedFd>,
);
/// The result of a [`Service::handle`] call.
///
/// On `std`, this is a tuple of the method reply and the file descriptors to send with it. On
/// `no_std`, this is just the method reply.
#[cfg(not(feature = "std"))]
pub type HandleResult<Params, ReplyStream, ReplyError> =
MethodReply<Params, ReplyStream, ReplyError>;

/// A service method call reply.
#[derive(Debug)]
pub enum MethodReply<Params, ReplyStream, ReplyError> {
Expand Down
2 changes: 1 addition & 1 deletion zlink-core/src/varlink_service/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ use super::{Error, Info, InterfaceDescription, OwnedError, OwnedInfo};
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "std")]
#[proxy(
interface = "org.varlink.service",
crate = "crate",
chain_name = "Chain"
)]
#[cfg(feature = "std")]
pub trait Proxy {
/// Get information about a Varlink service.
///
Expand Down
Loading