diff --git a/dgram/src/lib.rs b/dgram/src/lib.rs index 416935c41f..9aeeb3cec4 100644 --- a/dgram/src/lib.rs +++ b/dgram/src/lib.rs @@ -173,10 +173,3 @@ mod linux_imports { pub(super) use std::net::SocketAddrV6; pub(super) use std::os::fd::AsRawFd; } - -#[cfg(feature = "async")] -mod async_imports { - pub(super) use std::io::ErrorKind; - pub(super) use tokio::io::Interest; - pub(super) use tokio::net::UdpSocket; -} diff --git a/dgram/src/tokio.rs b/dgram/src/tokio.rs index 4b04138efa..cd64d07787 100644 --- a/dgram/src/tokio.rs +++ b/dgram/src/tokio.rs @@ -1,7 +1,11 @@ use crate::RecvData; +use std::io::ErrorKind; use std::io::Result; +use std::task::Context; +use std::task::Poll; -use crate::async_imports::*; +use tokio::io::Interest; +use tokio::net::UdpSocket; #[cfg(target_os = "linux")] mod linux { @@ -13,64 +17,73 @@ mod linux { use linux::*; #[cfg(target_os = "linux")] -pub async fn send_to( - socket: &UdpSocket, send_buf: &[u8], send_msg_settings: SendMsgCmsgSettings, -) -> Result { - loop { - // Important to use try_io so that Tokio can clear the socket's readiness - // flag - let res = socket.try_io(Interest::WRITABLE, || { - let fd = socket.as_fd(); - send_msg(fd, send_buf, send_msg_settings).map_err(Into::into) - }); - - match res { - Err(e) if e.kind() == ErrorKind::WouldBlock => - socket.writable().await?, - res => return res, - } +pub fn poll_send_to( + socket: &UdpSocket, ctx: &mut Context<'_>, send_buf: &[u8], + sendmsg_settings: SendMsgSettings, +) -> Poll> { + // We manually poll the socket here to register interest in + // Writable socket events for the given `ctx`. + // Under the hood, tokio's implementation just checks for + // EWOULDBLOCK and, if the socket is busy, registers the provided + // waker to be invoked when the socket is free. + match socket.poll_send_ready(ctx) { + Poll::Ready(Ok(())) => { + // Important to use try_io so that Tokio can clear the socket's + // readiness flag + match socket.try_io(Interest::WRITABLE, || { + let fd = socket.as_fd(); + send_msg(fd, send_buf, sendmsg_settings).map_err(Into::into) + }) { + Ok(n) => Poll::Ready(Ok(n)), + Err(e) if e.kind() == ErrorKind::WouldBlock => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + } + }, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, } } #[cfg(target_os = "linux")] -pub async fn recv_from( - socket: &UdpSocket, read_buf: &mut [u8], msg_flags: Option, - store_cmsg_settings: &mut RecvMsgCmsgSettings, -) -> Result { - loop { - // Important to use try_io so that Tokio can clear the socket's readiness - // flag - let res = socket.try_io(Interest::READABLE, || { - let fd = socket.as_fd(); - recv_msg( - fd, - read_buf, - msg_flags.unwrap_or(MsgFlags::empty()), - store_cmsg_settings, - ) - .map_err(Into::into) - }); - - match res { - Err(e) if e.kind() == ErrorKind::WouldBlock => - socket.readable().await?, - _ => return res, - } - } -} - -#[cfg(not(target_os = "linux"))] pub async fn send_to( - socket: &UdpSocket, client_addr: SocketAddr, + socket: &UdpSocket, send_buf: &[u8], sendmsg_settings: SendMsgSettings, ) -> Result { - socket.send_to(send_buf, client_addr).await + std::future::poll_fn(|mut cx| { + poll_send_to(socket, &mut cx, send_buf, sendmsg_settings) + }) + .await } -#[cfg(not(target_os = "linux"))] +#[cfg(target_os = "linux")] +pub fn poll_recv_from( + socket: &UdpSocket, ctx: &mut Context<'_>, recv_buf: &mut [u8], + recvmsg_settings: &mut RecvMsgSettings, +) -> Poll> { + match socket.poll_recv_ready(ctx) { + Poll::Ready(Ok(())) => { + // Important to use try_io so that Tokio can clear the socket's + // readiness flag + match socket.try_io(Interest::READABLE, || { + let fd = socket.as_fd(); + recv_msg(fd, recv_buf, recvmsg_settings).map_err(Into::into) + }) { + Ok(n) => Poll::Ready(Ok(n)), + Err(e) if e.kind() == ErrorKind::WouldBlock => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + } + }, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } +} + +#[cfg(target_os = "linux")] pub async fn recv_from( - socket: &UdpSocket, read_buf: &mut [u8], + socket: &UdpSocket, recv_buf: &mut [u8], + recvmsg_settings: &mut RecvMsgSettings, ) -> Result { - let recv = socket.recv(read_buf).await?; - - Ok(RecvData::from_bytes(bytes)) + std::future::poll_fn(|mut ctx| { + poll_recv_from(socket, &mut ctx, recv_buf, recvmsg_settings) + }) + .await }