diff --git a/http-body-util/src/channel.rs b/http-body-util/src/channel.rs index fc3220f..c7548b2 100644 --- a/http-body-util/src/channel.rs +++ b/http-body-util/src/channel.rs @@ -9,12 +9,16 @@ use std::{ use bytes::Buf; use http::HeaderMap; use http_body::{Body, Frame}; -use tokio::sync::mpsc; - -/// A body backed by a channel. -pub struct Channel { - rx_frame: mpsc::Receiver>, - rx_error: mpsc::Receiver, +use pin_project_lite::pin_project; +use tokio::sync::{mpsc, oneshot}; + +pin_project! { + /// A body backed by a channel. + pub struct Channel { + rx_frame: mpsc::Receiver>, + #[pin] + rx_error: oneshot::Receiver, + } } impl Channel { @@ -25,7 +29,7 @@ impl Channel { /// provided buffer capacity must be at least 1. pub fn new(buffer: usize) -> (Sender, Self) { let (tx_frame, rx_frame) = mpsc::channel(buffer); - let (tx_error, rx_error) = mpsc::channel(1); + let (tx_error, rx_error) = oneshot::channel(); (Sender { tx_frame, tx_error }, Self { rx_frame, rx_error }) } } @@ -38,16 +42,19 @@ where type Error = E; fn poll_frame( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { - match self.rx_frame.poll_recv(cx) { + let this = self.project(); + + match this.rx_frame.poll_recv(cx) { Poll::Ready(frame) => return Poll::Ready(frame.map(Ok)), Poll::Pending => {} } - match self.rx_error.poll_recv(cx) { - Poll::Ready(err) => return Poll::Ready(err.map(Err)), + use core::future::Future; + match this.rx_error.poll(cx) { + Poll::Ready(err) => return Poll::Ready(err.ok().map(Err)), Poll::Pending => {} } @@ -55,7 +62,7 @@ where } } -impl std::fmt::Debug for Channel { +impl std::fmt::Debug for Channel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Channel") .field("rx_frame", &self.rx_frame) @@ -67,7 +74,7 @@ impl std::fmt::Debug for Channel { /// A sender half created through [`Channel::new`]. pub struct Sender { tx_frame: mpsc::Sender>, - tx_error: mpsc::Sender, + tx_error: oneshot::Sender, } impl Sender { @@ -88,24 +95,11 @@ impl Sender { /// Aborts the body in an abnormal fashion. pub fn abort(self, error: E) { - match self.tx_error.try_send(error) { - Ok(_) => {} - Err(err) => { - match err { - mpsc::error::TrySendError::Full(_) => { - // Channel::new creates the error channel with space for 1 message and we - // only send once because this method consumes `self`. So the receiver - // can't be full. - unreachable!("error receiver should never be full") - } - mpsc::error::TrySendError::Closed(_) => {} - } - } - } + self.tx_error.send(error).ok(); } } -impl std::fmt::Debug for Sender { +impl std::fmt::Debug for Sender { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Sender") .field("tx_frame", &self.tx_frame)