diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 628db4bf81..b5eefb9b88 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -177,7 +177,7 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@nightly + uses: dtolnay/rust-toolchain@stable - name: Install cbindgen uses: taiki-e/cache-cargo-install-action@v1 diff --git a/CHANGELOG.md b/CHANGELOG.md index fcf2e98be7..815bcec794 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,40 @@ +### v1.0.0-rc.4 (2023-07-10) + + +#### Bug Fixes + +* **http1:** + * http1 server graceful shutdown fix (#3261) ([f4b51300](https://github.com/hyperium/hyper/commit/f4b513009d81083081d1c60c1981847bbb17dd5d)) + * send error on Incoming body when connection errors (#3256) ([52f19259](https://github.com/hyperium/hyper/commit/52f192593fb9ebcf6d3894e0c85cbf710da4decd), closes [#3253](https://github.com/hyperium/hyper/issues/3253)) + * properly end chunked bodies when it was known to be empty (#3254) ([fec64cf0](https://github.com/hyperium/hyper/commit/fec64cf0abdc678e30ca5f1b310c5118b2e01999), closes [#3252](https://github.com/hyperium/hyper/issues/3252)) + + +#### Features + +* **client:** Make clients able to use non-Send executor (#3184) ([d977f209](https://github.com/hyperium/hyper/commit/d977f209bc6068d8f878b22803fc42d90c887fcc), closes [#3017](https://github.com/hyperium/hyper/issues/3017)) +* **rt:** + * replace IO traits with hyper::rt ones (#3230) ([f9f65b7a](https://github.com/hyperium/hyper/commit/f9f65b7aa67fa3ec0267fe015945973726285bc2), closes [#3110](https://github.com/hyperium/hyper/issues/3110)) + * add downcast on `Sleep` trait (#3125) ([d92d3917](https://github.com/hyperium/hyper/commit/d92d3917d950e4c61c37c2170f3ce273d2a0f7d1), closes [#3027](https://github.com/hyperium/hyper/issues/3027)) +* **service:** change Service::call to take &self (#3223) ([d894439e](https://github.com/hyperium/hyper/commit/d894439e009aa75103f6382a7ba98fb17da72f02), closes [#3040](https://github.com/hyperium/hyper/issues/3040)) + + +#### Breaking Changes + +* Any IO transport type provided must not implement `hyper::rt::{Read, Write}` instead of + `tokio::io` traits. You can grab a helper type from `hyper-util` to wrap Tokio types, or implement the traits yourself, + if it's a custom type. + ([f9f65b7a](https://github.com/hyperium/hyper/commit/f9f65b7aa67fa3ec0267fe015945973726285bc2)) +* `client::conn::http2` types now use another generic for an `Executor`. + Code that names `Connection` needs to include the additional generic parameter. + ([d977f209](https://github.com/hyperium/hyper/commit/d977f209bc6068d8f878b22803fc42d90c887fcc)) +* The Service::call function no longer takes a mutable reference to self. + The FnMut trait bound on the service::util::service_fn function and the trait bound + on the impl for the ServiceFn struct were changed from FnMut to Fn. + + ([d894439e](https://github.com/hyperium/hyper/commit/d894439e009aa75103f6382a7ba98fb17da72f02)) + + + ### v1.0.0-rc.3 (2023-02-23) diff --git a/Cargo.toml b/Cargo.toml index 7665e97e88..1ece598be8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hyper" -version = "1.0.0-rc.3" +version = "1.0.0-rc.4" description = "A fast and correct HTTP library." readme = "README.md" homepage = "https://hyper.rs" @@ -25,7 +25,7 @@ futures-channel = "0.3" futures-util = { version = "0.3", default-features = false } http = "0.2" http-body = "=1.0.0-rc.2" -http-body-util = { version = "=0.1.0-rc.2", optional = true } +http-body-util = { version = "=0.1.0-rc.3", optional = true } httpdate = "1.0" httparse = "1.8" h2 = { version = "0.3.9", optional = true } @@ -41,7 +41,7 @@ libc = { version = "0.2", optional = true } [dev-dependencies] futures-util = { version = "0.3", default-features = false, features = ["alloc"] } -http-body-util = "=0.1.0-rc.2" +http-body-util = "=0.1.0-rc.3" matches = "0.1" num_cpus = "1.0" pretty_env_logger = "0.4" diff --git a/benches/end_to_end.rs b/benches/end_to_end.rs index 89c3caf4e2..3558e5c611 100644 --- a/benches/end_to_end.rs +++ b/benches/end_to_end.rs @@ -4,8 +4,7 @@ extern crate test; mod support; -// TODO: Reimplement Opts::bench using hyper::server::conn and hyper::client::conn -// (instead of Server and HttpClient). +// TODO: Reimplement parallel for HTTP/1 use std::convert::Infallible; use std::net::SocketAddr; @@ -315,7 +314,8 @@ impl Opts { let mut client = rt.block_on(async { if self.http2 { - let io = tokio::net::TcpStream::connect(&addr).await.unwrap(); + let tcp = tokio::net::TcpStream::connect(&addr).await.unwrap(); + let io = support::TokioIo::new(tcp); let (tx, conn) = hyper::client::conn::http2::Builder::new(support::TokioExecutor) .initial_stream_window_size(self.http2_stream_window) .initial_connection_window_size(self.http2_conn_window) @@ -328,7 +328,8 @@ impl Opts { } else if self.parallel_cnt > 1 { todo!("http/1 parallel >1"); } else { - let io = tokio::net::TcpStream::connect(&addr).await.unwrap(); + let tcp = tokio::net::TcpStream::connect(&addr).await.unwrap(); + let io = support::TokioIo::new(tcp); let (tx, conn) = hyper::client::conn::http1::Builder::new() .handshake(io) .await @@ -414,6 +415,7 @@ fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) -> SocketAddr { let opts = opts.clone(); rt.spawn(async move { while let Ok((sock, _)) = listener.accept().await { + let io = support::TokioIo::new(sock); if opts.http2 { tokio::spawn( hyper::server::conn::http2::Builder::new(support::TokioExecutor) @@ -421,7 +423,7 @@ fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) -> SocketAddr { .initial_connection_window_size(opts.http2_conn_window) .adaptive_window(opts.http2_adaptive_window) .serve_connection( - sock, + io, service_fn(move |req: Request| async move { let mut req_body = req.into_body(); while let Some(_chunk) = req_body.frame().await {} @@ -433,7 +435,7 @@ fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) -> SocketAddr { ); } else { tokio::spawn(hyper::server::conn::http1::Builder::new().serve_connection( - sock, + io, service_fn(move |req: Request| async move { let mut req_body = req.into_body(); while let Some(_chunk) = req_body.frame().await {} diff --git a/benches/pipeline.rs b/benches/pipeline.rs index a60100fa51..b79232de9b 100644 --- a/benches/pipeline.rs +++ b/benches/pipeline.rs @@ -3,6 +3,8 @@ extern crate test; +mod support; + use std::convert::Infallible; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream}; @@ -40,11 +42,12 @@ fn hello_world_16(b: &mut test::Bencher) { rt.spawn(async move { loop { let (stream, _addr) = listener.accept().await.expect("accept"); + let io = support::TokioIo::new(stream); http1::Builder::new() .pipeline_flush(true) .serve_connection( - stream, + io, service_fn(|_| async { Ok::<_, Infallible>(Response::new(Full::new(Bytes::from( "Hello, World!", diff --git a/benches/server.rs b/benches/server.rs index 17eefa0694..c5424105a8 100644 --- a/benches/server.rs +++ b/benches/server.rs @@ -3,6 +3,8 @@ extern crate test; +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::sync::mpsc; @@ -38,10 +40,11 @@ macro_rules! bench_server { rt.spawn(async move { loop { let (stream, _) = listener.accept().await.expect("accept"); + let io = support::TokioIo::new(stream); http1::Builder::new() .serve_connection( - stream, + io, service_fn(|_| async { Ok::<_, hyper::Error>( Response::builder() diff --git a/benches/support/mod.rs b/benches/support/mod.rs index 48e8048e8b..85cb67fd33 100644 --- a/benches/support/mod.rs +++ b/benches/support/mod.rs @@ -1,2 +1,2 @@ mod tokiort; -pub use tokiort::{TokioExecutor, TokioTimer}; +pub use tokiort::{TokioExecutor, TokioIo, TokioTimer}; diff --git a/benches/support/tokiort.rs b/benches/support/tokiort.rs index 67ae3a91aa..9a16e0ebad 100644 --- a/benches/support/tokiort.rs +++ b/benches/support/tokiort.rs @@ -41,6 +41,12 @@ impl Timer for TokioTimer { inner: tokio::time::sleep_until(deadline.into()), }) } + + fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { + if let Some(sleep) = sleep.as_mut().downcast_mut_pin::() { + sleep.reset(new_deadline.into()) + } + } } struct TokioTimeout { @@ -75,7 +81,156 @@ impl Future for TokioSleep { } } -// Use HasSleep to get tokio::time::Sleep to implement Unpin. -// see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html - impl Sleep for TokioSleep {} + +impl TokioSleep { + pub fn reset(self: Pin<&mut Self>, deadline: Instant) { + self.project().inner.as_mut().reset(deadline.into()); + } +} + +pin_project! { + #[derive(Debug)] + pub struct TokioIo { + #[pin] + inner: T, + } +} + +impl TokioIo { + pub fn new(inner: T) -> Self { + Self { inner } + } + + pub fn inner(self) -> T { + self.inner + } +} + +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +impl tokio::io::AsyncRead for TokioIo +where + T: hyper::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for TokioIo +where + T: hyper::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + hyper::rt::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + hyper::rt::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + hyper::rt::Write::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) + } +} diff --git a/examples/client.rs b/examples/client.rs index ffcc026719..046f59de02 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,6 +8,10 @@ use hyper::Request; use tokio::io::{self, AsyncWriteExt as _}; use tokio::net::TcpStream; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; @@ -40,8 +44,9 @@ async fn fetch_url(url: hyper::Uri) -> Result<()> { let port = url.port_u16().unwrap_or(80); let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); diff --git a/examples/client_json.rs b/examples/client_json.rs index 4ba6787a6e..6a6753528c 100644 --- a/examples/client_json.rs +++ b/examples/client_json.rs @@ -7,6 +7,10 @@ use hyper::{body::Buf, Request}; use serde::Deserialize; use tokio::net::TcpStream; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; @@ -29,8 +33,9 @@ async fn fetch_json(url: hyper::Uri) -> Result> { let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); diff --git a/examples/echo.rs b/examples/echo.rs index 7d3478a666..60d03b368d 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -10,6 +10,10 @@ use hyper::service::service_fn; use hyper::{body::Body, Method, Request, Response, StatusCode}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + /// This is our service handler. It receives a Request, routes on its /// path, and returns a Future of a Response. async fn echo( @@ -92,10 +96,11 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(echo)) + .serve_connection(io, service_fn(echo)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/gateway.rs b/examples/gateway.rs index 907f2fdba2..e0e3e053d0 100644 --- a/examples/gateway.rs +++ b/examples/gateway.rs @@ -4,6 +4,10 @@ use hyper::{server::conn::http1, service::service_fn}; use std::net::SocketAddr; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + #[tokio::main] async fn main() -> Result<(), Box> { pretty_env_logger::init(); @@ -20,6 +24,7 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); // This is the `Service` that will handle the connection. // `service_fn` is a helper to convert a function that @@ -42,9 +47,9 @@ async fn main() -> Result<(), Box> { async move { let client_stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(client_stream); - let (mut sender, conn) = - hyper::client::conn::http1::handshake(client_stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); @@ -56,10 +61,7 @@ async fn main() -> Result<(), Box> { }); tokio::task::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Failed to serve the connection: {:?}", err); } }); diff --git a/examples/hello.rs b/examples/hello.rs index a11199adb8..d9d6b8c4c7 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -10,6 +10,10 @@ use hyper::service::service_fn; use hyper::{Request, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // An async function that consumes a request, does nothing with it and returns a // response. async fn hello(_: Request) -> Result>, Infallible> { @@ -35,7 +39,10 @@ pub async fn main() -> Result<(), Box> { // has work to do. In this case, a connection arrives on the port we are listening on and // the task is woken up, at which point the task is then put back on a thread, and is // driven forward by the runtime, eventually yielding a TCP stream. - let (stream, _) = listener.accept().await?; + let (tcp, _) = listener.accept().await?; + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(tcp); // Spin up a new task in Tokio so we can continue to listen for new TCP connection on the // current task without waiting for the processing of the HTTP1 connection we just received @@ -44,7 +51,7 @@ pub async fn main() -> Result<(), Box> { // Handle the connection from the client using HTTP1 and pass any // HTTP requests received on that connection to the `hello` function if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(hello)) + .serve_connection(io, service_fn(hello)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/http_proxy.rs b/examples/http_proxy.rs index 0b4a6818b8..c36cc23778 100644 --- a/examples/http_proxy.rs +++ b/examples/http_proxy.rs @@ -12,6 +12,10 @@ use hyper::{Method, Request, Response}; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // To try this example: // 1. cargo run --example http_proxy // 2. config http_proxy in command line @@ -28,12 +32,13 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) - .serve_connection(stream, service_fn(proxy)) + .serve_connection(io, service_fn(proxy)) .with_upgrades() .await { @@ -88,11 +93,12 @@ async fn proxy( let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); let (mut sender, conn) = Builder::new() .preserve_header_case(true) .title_case_headers(true) - .handshake(stream) + .handshake(io) .await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -123,9 +129,10 @@ fn full>(chunk: T) -> BoxBody { // Create a TCP connection to host:port, build a tunnel between the connection and // the upgraded connection -async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> { +async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { // Connect to remote server let mut server = TcpStream::connect(addr).await?; + let mut upgraded = TokioIo::new(upgraded); // Proxying data let (from_client, from_server) = diff --git a/examples/multi_server.rs b/examples/multi_server.rs index 5eb520dbdb..51e6c39ca7 100644 --- a/examples/multi_server.rs +++ b/examples/multi_server.rs @@ -11,6 +11,10 @@ use hyper::service::service_fn; use hyper::{Request, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX1: &[u8] = b"The 1st service!"; static INDEX2: &[u8] = b"The 2nd service!"; @@ -33,10 +37,11 @@ async fn main() -> Result<(), Box> { let listener = TcpListener::bind(addr1).await.unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(index1)) + .serve_connection(io, service_fn(index1)) .await { println!("Error serving connection: {:?}", err); @@ -49,10 +54,11 @@ async fn main() -> Result<(), Box> { let listener = TcpListener::bind(addr2).await.unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(index2)) + .serve_connection(io, service_fn(index2)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/params.rs b/examples/params.rs index a902867f2e..3ba39326a1 100644 --- a/examples/params.rs +++ b/examples/params.rs @@ -13,6 +13,10 @@ use std::convert::Infallible; use std::net::SocketAddr; use url::form_urlencoded; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX: &[u8] = b"
Name:
Number:
"; static MISSING: &[u8] = b"Missing field"; static NOTNUMERIC: &[u8] = b"Number field is not numeric"; @@ -124,10 +128,11 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(param_example)) + .serve_connection(io, service_fn(param_example)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/send_file.rs b/examples/send_file.rs index a4514eb52b..ec489ec34f 100644 --- a/examples/send_file.rs +++ b/examples/send_file.rs @@ -10,6 +10,10 @@ use http_body_util::Full; use hyper::service::service_fn; use hyper::{Method, Request, Response, Result, StatusCode}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX: &str = "examples/send_file_index.html"; static NOTFOUND: &[u8] = b"Not Found"; @@ -24,10 +28,11 @@ async fn main() -> std::result::Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(response_examples)) + .serve_connection(io, service_fn(response_examples)) .await { println!("Failed to serve connection: {:?}", err); diff --git a/examples/service_struct_impl.rs b/examples/service_struct_impl.rs index 22cc2407dd..fc0f79356c 100644 --- a/examples/service_struct_impl.rs +++ b/examples/service_struct_impl.rs @@ -10,6 +10,10 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Mutex; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + type Counter = i32; #[tokio::main] @@ -21,11 +25,12 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() .serve_connection( - stream, + io, Svc { counter: Mutex::new(81818), }, diff --git a/examples/single_threaded.rs b/examples/single_threaded.rs index ee109d54fa..40e8d942b2 100644 --- a/examples/single_threaded.rs +++ b/examples/single_threaded.rs @@ -1,17 +1,36 @@ #![deny(warnings)] - +/// This example shows how to use hyper with a single-threaded runtime. +/// This example exists also to test if the code compiles when `Body` is not `Send`. +/// +/// This Example includes HTTP/1 and HTTP/2 server and client. +/// +/// In HTTP/1 it is possible to use a `!Send` `Body`type. +/// In HTTP/2 it is possible to use a `!Send` `Body` and `IO` type. +/// +/// The `Body` and `IOTypeNotSend` structs in this example are `!Send` +/// +/// For HTTP/2 this only works if the `Executor` trait is implemented without the `Send` bound. +use http_body_util::BodyExt; use hyper::server::conn::http2; use std::cell::Cell; use std::net::SocketAddr; use std::rc::Rc; +use tokio::io::{self, AsyncWriteExt}; use tokio::net::TcpListener; use hyper::body::{Body as HttpBody, Bytes, Frame}; use hyper::service::service_fn; +use hyper::Request; use hyper::{Error, Response}; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; +use std::thread; +use tokio::net::TcpStream; + +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; struct Body { // Our Body type is !Send and !Sync: @@ -40,30 +59,181 @@ impl HttpBody for Body { } } -fn main() -> Result<(), Box> { +fn main() { pretty_env_logger::init(); - // Configure a runtime that runs everything on the current thread - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("build runtime"); + let server_http2 = thread::spawn(move || { + // Configure a runtime for the server that runs everything on the current thread + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("build runtime"); + + // Combine it with a `LocalSet, which means it can spawn !Send futures... + let local = tokio::task::LocalSet::new(); + local.block_on(&rt, http2_server()).unwrap(); + }); + + let client_http2 = thread::spawn(move || { + // Configure a runtime for the client that runs everything on the current thread + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("build runtime"); + + // Combine it with a `LocalSet, which means it can spawn !Send futures... + let local = tokio::task::LocalSet::new(); + local + .block_on( + &rt, + http2_client("http://localhost:3000".parse::().unwrap()), + ) + .unwrap(); + }); + + let server_http1 = thread::spawn(move || { + // Configure a runtime for the server that runs everything on the current thread + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("build runtime"); - // Combine it with a `LocalSet, which means it can spawn !Send futures... - let local = tokio::task::LocalSet::new(); - local.block_on(&rt, run()) + // Combine it with a `LocalSet, which means it can spawn !Send futures... + let local = tokio::task::LocalSet::new(); + local.block_on(&rt, http1_server()).unwrap(); + }); + + let client_http1 = thread::spawn(move || { + // Configure a runtime for the client that runs everything on the current thread + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("build runtime"); + + // Combine it with a `LocalSet, which means it can spawn !Send futures... + let local = tokio::task::LocalSet::new(); + local + .block_on( + &rt, + http1_client("http://localhost:3001".parse::().unwrap()), + ) + .unwrap(); + }); + + server_http2.join().unwrap(); + client_http2.join().unwrap(); + + server_http1.join().unwrap(); + client_http1.join().unwrap(); } -async fn run() -> Result<(), Box> { - let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); +async fn http1_server() -> Result<(), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 3001)); + + let listener = TcpListener::bind(addr).await?; + + // For each connection, clone the counter to use in our service... + let counter = Rc::new(Cell::new(0)); + + loop { + let (stream, _) = listener.accept().await?; + + let io = TokioIo::new(stream); + + let cnt = counter.clone(); + + let service = service_fn(move |_| { + let prev = cnt.get(); + cnt.set(prev + 1); + let value = cnt.get(); + async move { Ok::<_, Error>(Response::new(Body::from(format!("Request #{}", value)))) } + }); + + tokio::task::spawn_local(async move { + if let Err(err) = hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} + +async fn http1_client(url: hyper::Uri) -> Result<(), Box> { + let host = url.host().expect("uri has no host"); + let port = url.port_u16().unwrap_or(80); + let addr = format!("{}:{}", host, port); + let stream = TcpStream::connect(addr).await?; + + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + + tokio::task::spawn_local(async move { + if let Err(err) = conn.await { + let mut stdout = io::stdout(); + stdout + .write_all(format!("Connection failed: {:?}", err).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); + } + }); + + let authority = url.authority().unwrap().clone(); + + // Make 4 requests + for _ in 0..4 { + let req = Request::builder() + .uri(url.clone()) + .header(hyper::header::HOST, authority.as_str()) + .body(Body::from("test".to_string()))?; + let mut res = sender.send_request(req).await?; + + let mut stdout = io::stdout(); + stdout + .write_all(format!("Response: {}\n", res.status()).as_bytes()) + .await + .unwrap(); + stdout + .write_all(format!("Headers: {:#?}\n", res.headers()).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); + + // Print the response body + while let Some(next) = res.frame().await { + let frame = next?; + if let Some(chunk) = frame.data_ref() { + stdout.write_all(&chunk).await.unwrap(); + } + } + stdout.write_all(b"\n-----------------\n").await.unwrap(); + stdout.flush().await.unwrap(); + } + Ok(()) +} + +async fn http2_server() -> Result<(), Box> { + let mut stdout = io::stdout(); + + let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); // Using a !Send request counter is fine on 1 thread... let counter = Rc::new(Cell::new(0)); let listener = TcpListener::bind(addr).await?; - println!("Listening on http://{}", addr); + + stdout + .write_all(format!("Listening on http://{}", addr).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); + loop { let (stream, _) = listener.accept().await?; + let io = IOTypeNotSend::new(TokioIo::new(stream)); // For each connection, clone the counter to use in our service... let cnt = counter.clone(); @@ -77,15 +247,76 @@ async fn run() -> Result<(), Box> { tokio::task::spawn_local(async move { if let Err(err) = http2::Builder::new(LocalExec) - .serve_connection(stream, service) + .serve_connection(io, service) .await { - println!("Error serving connection: {:?}", err); + let mut stdout = io::stdout(); + stdout + .write_all(format!("Error serving connection: {:?}", err).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); } }); } } +async fn http2_client(url: hyper::Uri) -> Result<(), Box> { + let host = url.host().expect("uri has no host"); + let port = url.port_u16().unwrap_or(80); + let addr = format!("{}:{}", host, port); + let stream = TcpStream::connect(addr).await?; + + let stream = IOTypeNotSend::new(TokioIo::new(stream)); + + let (mut sender, conn) = hyper::client::conn::http2::handshake(LocalExec, stream).await?; + + tokio::task::spawn_local(async move { + if let Err(err) = conn.await { + let mut stdout = io::stdout(); + stdout + .write_all(format!("Connection failed: {:?}", err).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); + } + }); + + let authority = url.authority().unwrap().clone(); + + // Make 4 requests + for _ in 0..4 { + let req = Request::builder() + .uri(url.clone()) + .header(hyper::header::HOST, authority.as_str()) + .body(Body::from("test".to_string()))?; + + let mut res = sender.send_request(req).await?; + + let mut stdout = io::stdout(); + stdout + .write_all(format!("Response: {}\n", res.status()).as_bytes()) + .await + .unwrap(); + stdout + .write_all(format!("Headers: {:#?}\n", res.headers()).as_bytes()) + .await + .unwrap(); + stdout.flush().await.unwrap(); + + // Print the response body + while let Some(next) = res.frame().await { + let frame = next?; + if let Some(chunk) = frame.data_ref() { + stdout.write_all(&chunk).await.unwrap(); + } + } + stdout.write_all(b"\n-----------------\n").await.unwrap(); + stdout.flush().await.unwrap(); + } + Ok(()) +} + // NOTE: This part is only needed for HTTP/2. HTTP/1 doesn't need an executor. // // Since the Server needs to spawn some background tasks, we needed @@ -102,3 +333,51 @@ where tokio::task::spawn_local(fut); } } + +struct IOTypeNotSend { + _marker: PhantomData<*const ()>, + stream: TokioIo, +} + +impl IOTypeNotSend { + fn new(stream: TokioIo) -> Self { + Self { + _marker: PhantomData, + stream, + } + } +} + +impl hyper::rt::Write for IOTypeNotSend { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } +} + +impl hyper::rt::Read for IOTypeNotSend { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} diff --git a/examples/state.rs b/examples/state.rs index 7d060efe1d..5263efdadc 100644 --- a/examples/state.rs +++ b/examples/state.rs @@ -12,6 +12,10 @@ use hyper::{server::conn::http1, service::service_fn}; use hyper::{Error, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + #[tokio::main] async fn main() -> Result<(), Box> { pretty_env_logger::init(); @@ -26,6 +30,7 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); // Each connection could send multiple requests, so // the `Service` needs a clone to handle later requests. @@ -46,10 +51,7 @@ async fn main() -> Result<(), Box> { } }); - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Error serving connection: {:?}", err); } } diff --git a/examples/upgrades.rs b/examples/upgrades.rs index 92a80d7567..f9754e5d49 100644 --- a/examples/upgrades.rs +++ b/examples/upgrades.rs @@ -16,11 +16,16 @@ use hyper::service::service_fn; use hyper::upgrade::Upgraded; use hyper::{Request, Response, StatusCode}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; /// Handle server-side I/O after HTTP upgraded. -async fn server_upgraded_io(mut upgraded: Upgraded) -> Result<()> { +async fn server_upgraded_io(upgraded: Upgraded) -> Result<()> { + let mut upgraded = TokioIo::new(upgraded); // we have an upgraded connection that we can read and // write on directly. // @@ -75,7 +80,8 @@ async fn server_upgrade(mut req: Request) -> Result Result<()> { +async fn client_upgraded_io(upgraded: Upgraded) -> Result<()> { + let mut upgraded = TokioIo::new(upgraded); // We've gotten an upgraded connection that we can read // and write directly on. Let's start out 'foobar' protocol. upgraded.write_all(b"foo=bar").await?; @@ -97,7 +103,8 @@ async fn client_upgrade_request(addr: SocketAddr) -> Result<()> { .unwrap(); let stream = TcpStream::connect(addr).await?; - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let io = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -146,10 +153,11 @@ async fn main() { tokio::select! { res = listener.accept() => { let (stream, _) = res.expect("Failed to accept"); + let io = TokioIo::new(stream); let mut rx = rx.clone(); tokio::task::spawn(async move { - let conn = http1::Builder::new().serve_connection(stream, service_fn(server_upgrade)); + let conn = http1::Builder::new().serve_connection(io, service_fn(server_upgrade)); // Don't forget to enable upgrades on the connection. let mut conn = conn.with_upgrades(); diff --git a/examples/web_api.rs b/examples/web_api.rs index 79834a0acd..91d9e9b72f 100644 --- a/examples/web_api.rs +++ b/examples/web_api.rs @@ -9,6 +9,10 @@ use hyper::service::service_fn; use hyper::{body::Incoming as IncomingBody, header, Method, Request, Response, StatusCode}; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + type GenericError = Box; type Result = std::result::Result; type BoxBody = http_body_util::combinators::BoxBody; @@ -30,8 +34,9 @@ async fn client_request_response() -> Result> { let host = req.uri().host().expect("uri has no host"); let port = req.uri().port_u16().expect("uri has no port"); let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -109,14 +114,12 @@ async fn main() -> Result<()> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { let service = service_fn(move |req| response_examples(req)); - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Failed to serve connection: {:?}", err); } }); diff --git a/src/client/conn/http1.rs b/src/client/conn/http1.rs index cecae92212..2034f0f2a6 100644 --- a/src/client/conn/http1.rs +++ b/src/client/conn/http1.rs @@ -3,10 +3,10 @@ use std::error::Error as StdError; use std::fmt; +use crate::rt::{Read, Write}; use bytes::Bytes; use http::{Request, Response}; use httparse::ParserConfig; -use tokio::io::{AsyncRead, AsyncWrite}; use super::super::dispatch; use crate::body::{Body, Incoming as IncomingBody}; @@ -49,7 +49,7 @@ pub struct Parts { #[must_use = "futures do nothing unless polled"] pub struct Connection where - T: AsyncRead + AsyncWrite + Send + 'static, + T: Read + Write + Send + 'static, B: Body + 'static, { inner: Option>, @@ -57,7 +57,7 @@ where impl Connection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + T: Read + Write + Send + Unpin + 'static, B: Body + 'static, B::Error: Into>, { @@ -114,7 +114,7 @@ pub struct Builder { /// See [`client::conn`](crate::client::conn) for more. pub async fn handshake(io: T) -> crate::Result<(SendRequest, Connection)> where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + Send + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, @@ -238,7 +238,7 @@ impl fmt::Debug for SendRequest { impl fmt::Debug for Connection where - T: AsyncRead + AsyncWrite + fmt::Debug + Send + 'static, + T: Read + Write + fmt::Debug + Send + 'static, B: Body + 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -248,8 +248,8 @@ where impl Future for Connection where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, - B: Body + Send + 'static, + T: Read + Write + Unpin + Send + 'static, + B: Body + 'static, B::Data: Send, B::Error: Into>, { @@ -470,7 +470,7 @@ impl Builder { io: T, ) -> impl Future, Connection)>> where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + Send + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, diff --git a/src/client/conn/http2.rs b/src/client/conn/http2.rs index a4cdc22f71..e9686347b3 100644 --- a/src/client/conn/http2.rs +++ b/src/client/conn/http2.rs @@ -1,23 +1,21 @@ //! HTTP/2 client connections -use std::error::Error as StdError; +use std::error::Error; use std::fmt; use std::marker::PhantomData; use std::sync::Arc; use std::time::Duration; +use crate::rt::{Read, Write}; use http::{Request, Response}; -use tokio::io::{AsyncRead, AsyncWrite}; use super::super::dispatch; use crate::body::{Body, Incoming as IncomingBody}; use crate::common::time::Time; -use crate::common::{ - exec::{BoxSendFuture, Exec}, - task, Future, Pin, Poll, -}; +use crate::common::{task, Future, Pin, Poll}; use crate::proto; -use crate::rt::{Executor, Timer}; +use crate::rt::bounds::ExecutorClient; +use crate::rt::Timer; /// The sender side of an established connection. pub struct SendRequest { @@ -37,20 +35,22 @@ impl Clone for SendRequest { /// In most cases, this should just be spawned into an executor, so that it /// can process incoming and outgoing messages, notice hangups, and the like. #[must_use = "futures do nothing unless polled"] -pub struct Connection +pub struct Connection where - T: AsyncRead + AsyncWrite + Send + 'static, + T: Read + Write + 'static + Unpin, B: Body + 'static, + E: ExecutorClient + Unpin, + B::Error: Into>, { - inner: (PhantomData, proto::h2::ClientTask), + inner: (PhantomData, proto::h2::ClientTask), } /// A builder to configure an HTTP connection. /// /// After setting options, the builder is used to create a handshake future. #[derive(Clone, Debug)] -pub struct Builder { - pub(super) exec: Exec, +pub struct Builder { + pub(super) exec: Ex, pub(super) timer: Time, h2_builder: proto::h2::client::Config, } @@ -59,13 +59,16 @@ pub struct Builder { /// /// This is a shortcut for `Builder::new().handshake(io)`. /// See [`client::conn`](crate::client::conn) for more. -pub async fn handshake(exec: E, io: T) -> crate::Result<(SendRequest, Connection)> +pub async fn handshake( + exec: E, + io: T, +) -> crate::Result<(SendRequest, Connection)> where - E: Executor + Send + Sync + 'static, - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Data: Send, - B::Error: Into>, + B::Error: Into>, + E: ExecutorClient + Unpin + Clone, { Builder::new(exec).handshake(io).await } @@ -188,12 +191,13 @@ impl fmt::Debug for SendRequest { // ===== impl Connection -impl Connection +impl Connection where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, - B: Body + Unpin + Send + 'static, + T: Read + Write + Unpin + 'static, + B: Body + Unpin + 'static, B::Data: Send, - B::Error: Into>, + B::Error: Into>, + E: ExecutorClient + Unpin, { /// Returns whether the [extended CONNECT protocol][1] is enabled or not. /// @@ -209,22 +213,26 @@ where } } -impl fmt::Debug for Connection +impl fmt::Debug for Connection where - T: AsyncRead + AsyncWrite + fmt::Debug + Send + 'static, + T: Read + Write + fmt::Debug + 'static + Unpin, B: Body + 'static, + E: ExecutorClient + Unpin, + B::Error: Into>, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Connection").finish() } } -impl Future for Connection +impl Future for Connection where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, - B: Body + Send + 'static, + T: Read + Write + Unpin + 'static, + B: Body + 'static + Unpin, B::Data: Send, - B::Error: Into>, + E: Unpin, + B::Error: Into>, + E: ExecutorClient + 'static + Send + Sync + Unpin, { type Output = crate::Result<()>; @@ -239,22 +247,22 @@ where // ===== impl Builder -impl Builder { +impl Builder +where + Ex: Clone, +{ /// Creates a new connection builder. #[inline] - pub fn new(exec: E) -> Builder - where - E: Executor + Send + Sync + 'static, - { + pub fn new(exec: Ex) -> Builder { Builder { - exec: Exec::new(exec), + exec, timer: Time::Empty, h2_builder: Default::default(), } } /// Provide a timer to execute background HTTP2 tasks. - pub fn timer(&mut self, timer: M) -> &mut Builder + pub fn timer(&mut self, timer: M) -> &mut Builder where M: Timer + Send + Sync + 'static, { @@ -269,7 +277,7 @@ impl Builder { /// /// If not set, hyper will use a default. /// - /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + /// [spec]: https://httpwg.org/specs/rfc9113.html#SETTINGS_INITIAL_WINDOW_SIZE pub fn initial_stream_window_size(&mut self, sz: impl Into>) -> &mut Self { if let Some(sz) = sz.into() { self.h2_builder.adaptive_window = false; @@ -388,12 +396,13 @@ impl Builder { pub fn handshake( &self, io: T, - ) -> impl Future, Connection)>> + ) -> impl Future, Connection)>> where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Data: Send, - B::Error: Into>, + B::Error: Into>, + Ex: ExecutorClient + Unpin, { let opts = self.clone(); diff --git a/src/client/conn/mod.rs b/src/client/conn/mod.rs index a70d86e5d3..eda436a8b8 100644 --- a/src/client/conn/mod.rs +++ b/src/client/conn/mod.rs @@ -9,7 +9,9 @@ //! higher-level [Client](super) API. //! //! ## Example -//! A simple example that uses the `SendRequest` struct to talk HTTP over a Tokio TCP stream +//! +//! A simple example that uses the `SendRequest` struct to talk HTTP over some TCP stream. +//! //! ```no_run //! # #[cfg(all(feature = "client", feature = "http1"))] //! # mod rt { @@ -17,38 +19,38 @@ //! use http::{Request, StatusCode}; //! use http_body_util::Empty; //! use hyper::client::conn; -//! use tokio::net::TcpStream; -//! -//! #[tokio::main] -//! async fn main() -> Result<(), Box> { -//! let target_stream = TcpStream::connect("example.com:80").await?; -//! -//! let (mut request_sender, connection) = conn::http1::handshake(target_stream).await?; -//! -//! // spawn a task to poll the connection and drive the HTTP state -//! tokio::spawn(async move { -//! if let Err(e) = connection.await { -//! eprintln!("Error in connection: {}", e); -//! } -//! }); -//! -//! let request = Request::builder() -//! // We need to manually add the host header because SendRequest does not -//! .header("Host", "example.com") -//! .method("GET") -//! .body(Empty::::new())?; -//! let response = request_sender.send_request(request).await?; -//! assert!(response.status() == StatusCode::OK); -//! -//! let request = Request::builder() -//! .header("Host", "example.com") -//! .method("GET") -//! .body(Empty::::new())?; -//! let response = request_sender.send_request(request).await?; -//! assert!(response.status() == StatusCode::OK); -//! Ok(()) -//! } -//! +//! # use hyper::rt::{Read, Write}; +//! # async fn run(tcp: I) -> Result<(), Box> +//! # where +//! # I: Read + Write + Unpin + Send + 'static, +//! # { +//! let (mut request_sender, connection) = conn::http1::handshake(tcp).await?; +//! +//! // spawn a task to poll the connection and drive the HTTP state +//! tokio::spawn(async move { +//! if let Err(e) = connection.await { +//! eprintln!("Error in connection: {}", e); +//! } +//! }); +//! +//! let request = Request::builder() +//! // We need to manually add the host header because SendRequest does not +//! .header("Host", "example.com") +//! .method("GET") +//! .body(Empty::::new())?; +//! +//! let response = request_sender.send_request(request).await?; +//! assert!(response.status() == StatusCode::OK); +//! +//! let request = Request::builder() +//! .header("Host", "example.com") +//! .method("GET") +//! .body(Empty::::new())?; +//! +//! let response = request_sender.send_request(request).await?; +//! assert!(response.status() == StatusCode::OK); +//! # Ok(()) +//! # } //! # } //! ``` diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index 3aef84012f..40cb554917 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -1,11 +1,18 @@ #[cfg(feature = "http2")] use std::future::Future; +use http::{Request, Response}; +use http_body::Body; +use pin_project_lite::pin_project; use tokio::sync::{mpsc, oneshot}; +use tracing::trace; +use crate::{ + body::Incoming, + common::{task, Poll}, +}; #[cfg(feature = "http2")] -use crate::common::Pin; -use crate::common::{task, Poll}; +use crate::{common::Pin, proto::h2::client::ResponseFutMap}; #[cfg(test)] pub(crate) type RetryPromise = oneshot::Receiver)>>; @@ -266,37 +273,57 @@ impl Callback { } } } +} - #[cfg(feature = "http2")] - pub(crate) async fn send_when( - self, - mut when: impl Future)>> + Unpin, - ) { - use futures_util::future; - use tracing::trace; - - let mut cb = Some(self); - - // "select" on this callback being canceled, and the future completing - future::poll_fn(move |cx| { - match Pin::new(&mut when).poll(cx) { - Poll::Ready(Ok(res)) => { - cb.take().expect("polled after complete").send(Ok(res)); - Poll::Ready(()) - } - Poll::Pending => { - // check if the callback is canceled - ready!(cb.as_mut().unwrap().poll_canceled(cx)); - trace!("send_when canceled"); - Poll::Ready(()) - } - Poll::Ready(Err(err)) => { - cb.take().expect("polled after complete").send(Err(err)); - Poll::Ready(()) - } +#[cfg(feature = "http2")] +pin_project! { + pub struct SendWhen + where + B: Body, + B: 'static, + { + #[pin] + pub(crate) when: ResponseFutMap, + #[pin] + pub(crate) call_back: Option, Response>>, + } +} + +#[cfg(feature = "http2")] +impl Future for SendWhen +where + B: Body + 'static, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + + let mut call_back = this.call_back.take().expect("polled after complete"); + + match Pin::new(&mut this.when).poll(cx) { + Poll::Ready(Ok(res)) => { + call_back.send(Ok(res)); + Poll::Ready(()) } - }) - .await + Poll::Pending => { + // check if the callback is canceled + match call_back.poll_canceled(cx) { + Poll::Ready(v) => v, + Poll::Pending => { + // Move call_back back to struct before return + this.call_back.set(Some(call_back)); + return std::task::Poll::Pending; + } + }; + trace!("send_when canceled"); + Poll::Ready(()) + } + Poll::Ready(Err(err)) => { + call_back.send(Err(err)); + Poll::Ready(()) + } + } } } diff --git a/src/common/exec.rs b/src/common/exec.rs index ef006c9d84..69d19e9bb7 100644 --- a/src/common/exec.rs +++ b/src/common/exec.rs @@ -1,50 +1,14 @@ -use std::fmt; use std::future::Future; use std::pin::Pin; -use std::sync::Arc; - -use crate::rt::Executor; - -pub(crate) type BoxSendFuture = Pin + Send>>; - -// Executor must be provided by the user -#[derive(Clone)] -pub(crate) struct Exec(Arc + Send + Sync>); - -// ===== impl Exec ===== - -impl Exec { - pub(crate) fn new(exec: E) -> Self - where - E: Executor + Send + Sync + 'static, - { - Self(Arc::new(exec)) - } - - pub(crate) fn execute(&self, fut: F) - where - F: Future + Send + 'static, - { - self.0.execute(Box::pin(fut)) - } -} - -impl fmt::Debug for Exec { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Exec").finish() - } -} // If http2 is not enable, we just have a stub here, so that the trait bounds // that *would* have been needed are still checked. Why? // // Because enabling `http2` shouldn't suddenly add new trait bounds that cause // a compilation error. -#[cfg(not(feature = "http2"))] -#[allow(missing_debug_implementations)] + pub struct H2Stream(std::marker::PhantomData<(F, B)>); -#[cfg(not(feature = "http2"))] impl Future for H2Stream where F: Future, E>>, diff --git a/src/common/io/compat.rs b/src/common/io/compat.rs new file mode 100644 index 0000000000..3320e4ff44 --- /dev/null +++ b/src/common/io/compat.rs @@ -0,0 +1,150 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// This adapts from `hyper` IO traits to the ones in Tokio. +/// +/// This is currently used by `h2`, and by hyper internal unit tests. +#[derive(Debug)] +pub(crate) struct Compat(pub(crate) T); + +pub(crate) fn compat(io: T) -> Compat { + Compat(io) +} + +impl Compat { + fn p(self: Pin<&mut Self>) -> Pin<&mut T> { + // SAFETY: The simplest of projections. This is just + // a wrapper, we don't do anything that would undo the projection. + unsafe { self.map_unchecked_mut(|me| &mut me.0) } + } +} + +impl tokio::io::AsyncRead for Compat +where + T: crate::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let (new_init, new_filled) = unsafe { + let mut buf = crate::rt::ReadBuf::uninit(tbuf.inner_mut()); + buf.set_init(init); + buf.set_filled(filled); + + match crate::rt::Read::poll_read(self.p(), cx, buf.unfilled()) { + Poll::Ready(Ok(())) => (buf.init_len(), buf.len()), + other => return other, + } + }; + + let n_init = new_init - init; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(new_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for Compat +where + T: crate::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + crate::rt::Write::poll_write(self.p(), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + crate::rt::Write::poll_flush(self.p(), cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + crate::rt::Write::poll_shutdown(self.p(), cx) + } + + fn is_write_vectored(&self) -> bool { + crate::rt::Write::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + crate::rt::Write::poll_write_vectored(self.p(), cx, bufs) + } +} + +#[cfg(test)] +impl crate::rt::Read for Compat +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: crate::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.p(), cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +impl crate::rt::Write for Compat +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.p(), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.p(), cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.p(), cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.p(), cx, bufs) + } +} diff --git a/src/common/io/mod.rs b/src/common/io/mod.rs index 2e6d506153..6ad07bb771 100644 --- a/src/common/io/mod.rs +++ b/src/common/io/mod.rs @@ -1,3 +1,7 @@ +#[cfg(any(feature = "http2", test))] +mod compat; mod rewind; +#[cfg(any(feature = "http2", test))] +pub(crate) use self::compat::{compat, Compat}; pub(crate) use self::rewind::Rewind; diff --git a/src/common/io/rewind.rs b/src/common/io/rewind.rs index 5642d897d1..f6b6bab3c7 100644 --- a/src/common/io/rewind.rs +++ b/src/common/io/rewind.rs @@ -2,9 +2,9 @@ use std::marker::Unpin; use std::{cmp, io}; use bytes::{Buf, Bytes}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::common::{task, Pin, Poll}; +use crate::rt::{Read, ReadBufCursor, Write}; /// Combine a buffer with an IO, rewinding reads to use the buffer. #[derive(Debug)] @@ -44,14 +44,14 @@ impl Rewind { // } } -impl AsyncRead for Rewind +impl Read for Rewind where - T: AsyncRead + Unpin, + T: Read + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, + mut buf: ReadBufCursor<'_>, ) -> Poll> { if let Some(mut prefix) = self.pre.take() { // If there are no remaining bytes, let the bytes get dropped. @@ -72,9 +72,9 @@ where } } -impl AsyncWrite for Rewind +impl Write for Rewind where - T: AsyncWrite + Unpin, + T: Write + Unpin, { fn poll_write( mut self: Pin<&mut Self>, @@ -109,6 +109,7 @@ where mod tests { // FIXME: re-implement tests with `async/await`, this import should // trigger a warning to remind us + use super::super::compat; use super::Rewind; use bytes::Bytes; use tokio::io::AsyncReadExt; @@ -120,14 +121,14 @@ mod tests { let mock = tokio_test::io::Builder::new().read(&underlying).build(); - let mut stream = Rewind::new(mock); + let mut stream = compat(Rewind::new(compat(mock))); // Read off some bytes, ensure we filled o1 let mut buf = [0; 2]; stream.read_exact(&mut buf).await.expect("read1"); // Rewind the stream so that it is as if we never read in the first place. - stream.rewind(Bytes::copy_from_slice(&buf[..])); + stream.0.rewind(Bytes::copy_from_slice(&buf[..])); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); @@ -143,13 +144,13 @@ mod tests { let mock = tokio_test::io::Builder::new().read(&underlying).build(); - let mut stream = Rewind::new(mock); + let mut stream = compat(Rewind::new(compat(mock))); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); // Rewind the stream so that it is as if we never read in the first place. - stream.rewind(Bytes::copy_from_slice(&buf[..])); + stream.0.rewind(Bytes::copy_from_slice(&buf[..])); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); diff --git a/src/common/mod.rs b/src/common/mod.rs index 67b2bbde59..2392851951 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -10,7 +10,7 @@ macro_rules! ready { pub(crate) mod buf; #[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] pub(crate) mod date; -#[cfg(any(feature = "http1", feature = "http2", feature = "server"))] +#[cfg(not(feature = "http2"))] pub(crate) mod exec; pub(crate) mod io; mod never; diff --git a/src/ffi/io.rs b/src/ffi/io.rs index bff666dbcf..1d198820a6 100644 --- a/src/ffi/io.rs +++ b/src/ffi/io.rs @@ -2,8 +2,8 @@ use std::ffi::c_void; use std::pin::Pin; use std::task::{Context, Poll}; +use crate::rt::{Read, Write}; use libc::size_t; -use tokio::io::{AsyncRead, AsyncWrite}; use super::task::hyper_context; @@ -120,13 +120,13 @@ extern "C" fn write_noop( 0 } -impl AsyncRead for hyper_io { +impl Read for hyper_io { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, + mut buf: crate::rt::ReadBufCursor<'_>, ) -> Poll> { - let buf_ptr = unsafe { buf.unfilled_mut() }.as_mut_ptr() as *mut u8; + let buf_ptr = unsafe { buf.as_mut() }.as_mut_ptr() as *mut u8; let buf_len = buf.remaining(); match (self.read)(self.userdata, hyper_context::wrap(cx), buf_ptr, buf_len) { @@ -138,15 +138,14 @@ impl AsyncRead for hyper_io { ok => { // We have to trust that the user's read callback actually // filled in that many bytes... :( - unsafe { buf.assume_init(ok) }; - buf.advance(ok); + unsafe { buf.advance(ok) }; Poll::Ready(Ok(())) } } } } -impl AsyncWrite for hyper_io { +impl Write for hyper_io { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/src/ffi/task.rs b/src/ffi/task.rs index ef54fe408f..a973a7bab3 100644 --- a/src/ffi/task.rs +++ b/src/ffi/task.rs @@ -177,8 +177,12 @@ impl WeakExec { } } -impl crate::rt::Executor> for WeakExec { - fn execute(&self, fut: BoxFuture<()>) { +impl crate::rt::Executor for WeakExec +where + F: Future + Send + 'static, + F::Output: Send + Sync + AsTaskType, +{ + fn execute(&self, fut: F) { if let Some(exec) = self.0.upgrade() { exec.spawn(hyper_task::boxed(fut)); } diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index b7c619683c..e0d65bd2d4 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -4,11 +4,11 @@ use std::marker::PhantomData; #[cfg(feature = "server")] use std::time::Duration; +use crate::rt::{Read, Write}; use bytes::{Buf, Bytes}; use http::header::{HeaderValue, CONNECTION}; use http::{HeaderMap, Method, Version}; use httparse::ParserConfig; -use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, error, trace}; use super::io::Buffered; @@ -25,7 +25,7 @@ use crate::rt::Sleep; const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; /// This handles a connection, which will have been established over an -/// `AsyncRead + AsyncWrite` (like a socket), and will likely include multiple +/// `Read + Write` (like a socket), and will likely include multiple /// `Transaction`s over HTTP. /// /// The connection will determine when a message begins and ends as well as @@ -39,7 +39,7 @@ pub(crate) struct Conn { impl Conn where - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, B: Buf, T: Http1Transaction, { @@ -175,6 +175,13 @@ where } } + #[cfg(feature = "server")] + pub(crate) fn has_initial_read_write_state(&self) -> bool { + matches!(self.state.reading, Reading::Init) + && matches!(self.state.writing, Writing::Init) + && self.io.read_buf().is_empty() + } + fn should_error_on_eof(&self) -> bool { // If we're idle, it's probably just the connection closing gracefully. T::should_error_on_parse_eof() && !self.state.is_idle() @@ -1037,12 +1044,13 @@ mod tests { #[bench] fn bench_read_head_short(b: &mut ::test::Bencher) { use super::*; + use crate::common::io::Compat; let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"; let len = s.len(); b.bytes = len as u64; // an empty IO, we'll be skipping and using the read buffer anyways - let io = tokio_test::io::Builder::new().build(); + let io = Compat(tokio_test::io::Builder::new().build()); let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io); *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]); conn.state.cached_headers = Some(HeaderMap::with_capacity(2)); diff --git a/src/proto/h1/decode.rs b/src/proto/h1/decode.rs index 4077b22062..47d9bbd081 100644 --- a/src/proto/h1/decode.rs +++ b/src/proto/h1/decode.rs @@ -428,9 +428,9 @@ impl StdError for IncompleteBody {} #[cfg(test)] mod tests { use super::*; + use crate::rt::{Read, ReadBuf}; use std::pin::Pin; use std::time::Duration; - use tokio::io::{AsyncRead, ReadBuf}; impl<'a> MemRead for &'a [u8] { fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll> { @@ -446,11 +446,11 @@ mod tests { } } - impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) { + impl<'a> MemRead for &'a mut (dyn Read + Unpin) { fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll> { let mut v = vec![0; len]; let mut buf = ReadBuf::new(&mut v); - ready!(Pin::new(self).poll_read(cx, &mut buf)?); + ready!(Pin::new(self).poll_read(cx, buf.unfilled())?); Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled()))) } } @@ -629,7 +629,7 @@ mod tests { async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String { let mut outs = Vec::new(); - let mut ins = if block_at == 0 { + let mut ins = crate::common::io::compat(if block_at == 0 { tokio_test::io::Builder::new() .wait(Duration::from_millis(10)) .read(content) @@ -640,9 +640,9 @@ mod tests { .wait(Duration::from_millis(10)) .read(&content[block_at..]) .build() - }; + }); - let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin); + let mut ins = &mut ins as &mut (dyn Read + Unpin); loop { let buf = decoder diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 6141b296f8..eea31a1105 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -1,8 +1,8 @@ use std::error::Error as StdError; +use crate::rt::{Read, Write}; use bytes::{Buf, Bytes}; use http::Request; -use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, trace}; use super::{Http1Transaction, Wants}; @@ -64,7 +64,7 @@ where RecvItem = MessageHead, > + Unpin, D::PollError: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, T: Http1Transaction + Unpin, Bs: Body + 'static, Bs::Error: Into>, @@ -82,7 +82,11 @@ where #[cfg(feature = "server")] pub(crate) fn disable_keep_alive(&mut self) { self.conn.disable_keep_alive(); - if self.conn.is_write_closed() { + + // If keep alive has been disabled and no read or write has been seen on + // the connection yet, we must be in a state where the server is being asked to + // shut down before any data has been seen on the connection + if self.conn.is_write_closed() || self.conn.has_initial_read_write_state() { self.close(); } } @@ -93,7 +97,7 @@ where } /// Run this dispatcher until HTTP says this connection is done, - /// but don't call `AsyncWrite::shutdown` on the underlying IO. + /// but don't call `Write::shutdown` on the underlying IO. /// /// This is useful for old-style HTTP upgrades, but ignores /// newer-style upgrade API. @@ -422,7 +426,7 @@ where RecvItem = MessageHead, > + Unpin, D::PollError: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, T: Http1Transaction + Unpin, Bs: Body + 'static, Bs::Error: Into>, @@ -660,6 +664,7 @@ cfg_client! { #[cfg(test)] mod tests { use super::*; + use crate::common::io::compat; use crate::proto::h1::ClientTransaction; use std::time::Duration; @@ -673,7 +678,7 @@ mod tests { // Block at 0 for now, but we will release this response before // the request is ready to write later... let (mut tx, rx) = crate::client::dispatch::channel(); - let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); let mut dispatcher = Dispatcher::new(Client::new(rx), conn); // First poll is needed to allow tx to send... @@ -710,7 +715,7 @@ mod tests { .build_with_handle(); let (mut tx, rx) = crate::client::dispatch::channel(); - let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); conn.set_write_strategy_queue(); let dispatcher = Dispatcher::new(Client::new(rx), conn); @@ -741,7 +746,7 @@ mod tests { .build(); let (mut tx, rx) = crate::client::dispatch::channel(); - let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn)); // First poll is needed to allow tx to send... diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index da4101b6fb..b49cda3dd3 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -6,8 +6,8 @@ use std::io::{self, IoSlice}; use std::marker::Unpin; use std::mem::MaybeUninit; +use crate::rt::{Read, ReadBuf, Write}; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{debug, trace}; use super::{Http1Transaction, ParseContext, ParsedMessage}; @@ -55,7 +55,7 @@ where impl Buffered where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, B: Buf, { pub(crate) fn new(io: T) -> Buffered { @@ -251,7 +251,7 @@ where let dst = self.read_buf.chunk_mut(); let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; let mut buf = ReadBuf::uninit(dst); - match Pin::new(&mut self.io).poll_read(cx, &mut buf) { + match Pin::new(&mut self.io).poll_read(cx, buf.unfilled()) { Poll::Ready(Ok(_)) => { let n = buf.filled().len(); trace!("received {} bytes", n); @@ -359,7 +359,7 @@ pub(crate) trait MemRead { impl MemRead for Buffered where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, B: Buf, { fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll> { @@ -662,6 +662,7 @@ enum WriteStrategy { #[cfg(test)] mod tests { + use crate::common::io::compat; use crate::common::time::Time; use super::*; @@ -717,7 +718,7 @@ mod tests { .wait(Duration::from_secs(1)) .build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); // We expect a `parse` to be not ready, and so can't await it directly. // Rather, this `poll_fn` will wrap the `Poll` result. @@ -862,7 +863,7 @@ mod tests { #[cfg(debug_assertions)] // needs to trigger a debug_assert fn write_buf_requires_non_empty_bufs() { let mock = Mock::new().build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.buffer(Cursor::new(Vec::new())); } @@ -897,7 +898,7 @@ mod tests { let mock = Mock::new().write(b"hello world, it's hyper!").build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.write_buf.set_strategy(WriteStrategy::Flatten); buffered.headers_buf().extend(b"hello "); @@ -956,7 +957,7 @@ mod tests { .write(b"hyper!") .build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.write_buf.set_strategy(WriteStrategy::Queue); // we have 4 buffers, and vec IO disabled, but explicitly said diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index adadfce68d..b8d9951928 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -1,25 +1,31 @@ -use std::error::Error as StdError; +use std::marker::PhantomData; + use std::time::Duration; +use crate::rt::{Read, Write}; use bytes::Bytes; +use futures_channel::mpsc::{Receiver, Sender}; use futures_channel::{mpsc, oneshot}; -use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; -use futures_util::stream::StreamExt as _; -use h2::client::{Builder, SendRequest}; +use futures_util::future::{self, Either, FutureExt as _, Select}; +use futures_util::stream::{StreamExt as _, StreamFuture}; +use h2::client::{Builder, Connection, SendRequest}; use h2::SendStream; use http::{Method, StatusCode}; -use tokio::io::{AsyncRead, AsyncWrite}; +use pin_project_lite::pin_project; use tracing::{debug, trace, warn}; +use super::ping::{Ponger, Recorder}; use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; use crate::body::{Body, Incoming as IncomingBody}; -use crate::client::dispatch::Callback; +use crate::client::dispatch::{Callback, SendWhen}; +use crate::common::io::Compat; use crate::common::time::Time; -use crate::common::{exec::Exec, task, Future, Never, Pin, Poll}; +use crate::common::{task, Future, Never, Pin, Poll}; use crate::ext::Protocol; use crate::headers; use crate::proto::h2::UpgradedSendStream; use crate::proto::Dispatched; +use crate::rt::bounds::ExecutorClient; use crate::upgrade::Upgraded; use crate::{Request, Response}; use h2::client::ResponseFuture; @@ -98,20 +104,22 @@ fn new_ping_config(config: &Config) -> ping::Config { } } -pub(crate) async fn handshake( +pub(crate) async fn handshake( io: T, req_rx: ClientRx, config: &Config, - exec: Exec, + mut exec: E, timer: Time, -) -> crate::Result> +) -> crate::Result> where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static, - B: Body, + T: Read + Write + Unpin + 'static, + B: Body + 'static, B::Data: Send + 'static, + E: ExecutorClient + Unpin, + B::Error: Into>, { let (h2_tx, mut conn) = new_builder(config) - .handshake::<_, SendBuf>(io) + .handshake::<_, SendBuf>(crate::common::io::compat(io)) .await .map_err(crate::Error::new_h2)?; @@ -122,40 +130,24 @@ where let (conn_drop_ref, rx) = mpsc::channel(1); let (cancel_tx, conn_eof) = oneshot::channel(); - let conn_drop_rx = rx.into_future().map(|(item, _rx)| { - if let Some(never) = item { - match never {} - } - }); + let conn_drop_rx = rx.into_future(); let ping_config = new_ping_config(&config); let (conn, ping) = if ping_config.is_enabled() { let pp = conn.ping_pong().expect("conn.ping_pong"); - let (recorder, mut ponger) = ping::channel(pp, ping_config, timer); + let (recorder, ponger) = ping::channel(pp, ping_config, timer); - let conn = future::poll_fn(move |cx| { - match ponger.poll(cx) { - Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { - conn.set_target_window_size(wnd); - conn.set_initial_window_size(wnd)?; - } - Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { - debug!("connection keep-alive timed out"); - return Poll::Ready(Ok(())); - } - Poll::Pending => {} - } - - Pin::new(&mut conn).poll(cx) - }); + let conn: Conn<_, B> = Conn::new(ponger, conn); (Either::Left(conn), recorder) } else { (Either::Right(conn), ping::disabled()) }; - let conn = conn.map_err(|e| debug!("connection error: {}", e)); + let conn: ConnMapErr = ConnMapErr { conn }; - exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); + exec.execute_h2_future(H2ClientFuture::Task { + task: ConnTask::new(conn, conn_drop_rx, cancel_tx), + }); Ok(ClientTask { ping, @@ -165,25 +157,195 @@ where h2_tx, req_rx, fut_ctx: None, + marker: PhantomData, }) } -async fn conn_task(conn: C, drop_rx: D, cancel_tx: oneshot::Sender) +pin_project! { + struct Conn + where + B: Body, + { + #[pin] + ponger: Ponger, + #[pin] + conn: Connection, SendBuf<::Data>>, + } +} + +impl Conn +where + B: Body, + T: Read + Write + Unpin, +{ + fn new(ponger: Ponger, conn: Connection, SendBuf<::Data>>) -> Self { + Conn { ponger, conn } + } +} + +impl Future for Conn where - C: Future + Unpin, - D: Future + Unpin, + B: Body, + T: Read + Write + Unpin, { - match future::select(conn, drop_rx).await { - Either::Left(_) => { - // ok or err, the `conn` has finished + type Output = Result<(), h2::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + match this.ponger.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + this.conn.set_target_window_size(wnd); + this.conn.set_initial_window_size(wnd)?; + } + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("connection keep-alive timed out"); + return Poll::Ready(Ok(())); + } + Poll::Pending => {} } - Either::Right(((), conn)) => { - // mpsc has been dropped, hopefully polling - // the connection some more should start shutdown - // and then close - trace!("send_request dropped, starting conn shutdown"); - drop(cancel_tx); - let _ = conn.await; + + Pin::new(&mut this.conn).poll(cx) + } +} + +pin_project! { + struct ConnMapErr + where + B: Body, + T: Read, + T: Write, + T: Unpin, + { + #[pin] + conn: Either, Connection, SendBuf<::Data>>>, + } +} + +impl Future for ConnMapErr +where + B: Body, + T: Read + Write + Unpin, +{ + type Output = Result<(), ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + self.project() + .conn + .poll(cx) + .map_err(|e| debug!("connection error: {}", e)) + } +} + +pin_project! { + pub struct ConnTask + where + B: Body, + T: Read, + T: Write, + T: Unpin, + { + #[pin] + select: Select, StreamFuture>>, + #[pin] + cancel_tx: Option>, + conn: Option>, + } +} + +impl ConnTask +where + B: Body, + T: Read + Write + Unpin, +{ + fn new( + conn: ConnMapErr, + drop_rx: StreamFuture>, + cancel_tx: oneshot::Sender, + ) -> Self { + Self { + select: future::select(conn, drop_rx), + cancel_tx: Some(cancel_tx), + conn: None, + } + } +} + +impl Future for ConnTask +where + B: Body, + T: Read + Write + Unpin, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + + if let Some(conn) = this.conn { + conn.poll_unpin(cx).map(|_| ()) + } else { + match ready!(this.select.poll_unpin(cx)) { + Either::Left((_, _)) => { + // ok or err, the `conn` has finished + return Poll::Ready(()); + } + Either::Right((_, b)) => { + // mpsc has been dropped, hopefully polling + // the connection some more should start shutdown + // and then close + trace!("send_request dropped, starting conn shutdown"); + drop(this.cancel_tx.take().expect("Future polled twice")); + this.conn = &mut Some(b); + return Poll::Pending; + } + } + } + } +} + +pin_project! { + #[project = H2ClientFutureProject] + pub enum H2ClientFuture + where + B: http_body::Body, + B: 'static, + B::Error: Into>, + T: Read, + T: Write, + T: Unpin, + { + Pipe { + #[pin] + pipe: PipeMap, + }, + Send { + #[pin] + send_when: SendWhen, + }, + Task { + #[pin] + task: ConnTask, + }, + } +} + +impl Future for H2ClientFuture +where + B: http_body::Body + 'static, + B::Error: Into>, + T: Read + Write + Unpin, +{ + type Output = (); + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + + match this { + H2ClientFutureProject::Pipe { pipe } => pipe.poll(cx), + H2ClientFutureProject::Send { send_when } => send_when.poll(cx), + H2ClientFutureProject::Task { task } => task.poll(cx), } } } @@ -202,43 +364,89 @@ where impl Unpin for FutCtx {} -pub(crate) struct ClientTask +pub(crate) struct ClientTask where B: Body, + E: Unpin, { ping: ping::Recorder, conn_drop_ref: ConnDropRef, conn_eof: ConnEof, - executor: Exec, + executor: E, h2_tx: SendRequest>, req_rx: ClientRx, fut_ctx: Option>, + marker: PhantomData, } -impl ClientTask +impl ClientTask where B: Body + 'static, + E: ExecutorClient + Unpin, + B::Error: Into>, + T: Read + Write + Unpin, { pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { self.h2_tx.is_extended_connect_protocol_enabled() } } -impl ClientTask +pin_project! { + pub struct PipeMap + where + S: Body, + { + #[pin] + pipe: PipeToSendStream, + #[pin] + conn_drop_ref: Option>, + #[pin] + ping: Option, + } +} + +impl Future for PipeMap where - B: Body + Send + 'static, + B: http_body::Body, + B::Error: Into>, +{ + type Output = (); + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let mut this = self.project(); + + match this.pipe.poll_unpin(cx) { + Poll::Ready(result) => { + if let Err(e) = result { + debug!("client request body error: {}", e); + } + drop(this.conn_drop_ref.take().expect("Future polled twice")); + drop(this.ping.take().expect("Future polled twice")); + return Poll::Ready(()); + } + Poll::Pending => (), + }; + Poll::Pending + } +} + +impl ClientTask +where + B: Body + 'static + Unpin, B::Data: Send, - B::Error: Into>, + E: ExecutorClient + Unpin, + B::Error: Into>, + T: Read + Write + Unpin, { fn poll_pipe(&mut self, f: FutCtx, cx: &mut task::Context<'_>) { let ping = self.ping.clone(); + let send_stream = if !f.is_connect { if !f.eos { - let mut pipe = Box::pin(PipeToSendStream::new(f.body, f.body_tx)).map(|res| { - if let Err(e) = res { - debug!("client request body error: {}", e); - } - }); + let mut pipe = PipeToSendStream::new(f.body, f.body_tx); // eagerly see if the body pipe is ready and // can thus skip allocating in the executor @@ -250,13 +458,15 @@ where // "open stream" alive while this body is // still sending... let ping = ping.clone(); - let pipe = pipe.map(move |x| { - drop(conn_drop_ref); - drop(ping); - x - }); + + let pipe = PipeMap { + pipe, + conn_drop_ref: Some(conn_drop_ref), + ping: Some(ping), + }; // Clear send task - self.executor.execute(pipe); + self.executor + .execute_h2_future(H2ClientFuture::Pipe { pipe: pipe }); } } } @@ -266,7 +476,49 @@ where Some(f.body_tx) }; - let fut = f.fut.map(move |result| match result { + self.executor.execute_h2_future(H2ClientFuture::Send { + send_when: SendWhen { + when: ResponseFutMap { + fut: f.fut, + ping: Some(ping), + send_stream: Some(send_stream), + }, + call_back: Some(f.cb), + }, + }); + } +} + +pin_project! { + pub(crate) struct ResponseFutMap + where + B: Body, + B: 'static, + { + #[pin] + fut: ResponseFuture, + #[pin] + ping: Option, + #[pin] + send_stream: Option::Data>>>>, + } +} + +impl Future for ResponseFutMap +where + B: Body + 'static, +{ + type Output = Result, (crate::Error, Option>)>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + + let result = ready!(this.fut.poll(cx)); + + let ping = this.ping.take().expect("Future polled twice"); + let send_stream = this.send_stream.take().expect("Future polled twice"); + + match result { Ok(res) => { // record that we got the response headers ping.record_non_data(); @@ -277,17 +529,17 @@ where warn!("h2 connect response with non-zero body not supported"); send_stream.send_reset(h2::Reason::INTERNAL_ERROR); - return Err(( + return Poll::Ready(Err(( crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()), - None, - )); + None::>, + ))); } let (parts, recv_stream) = res.into_parts(); let mut res = Response::from_parts(parts, IncomingBody::empty()); let (pending, on_upgrade) = crate::upgrade::pending(); let io = H2Upgraded { - ping, + ping: ping, send_stream: unsafe { UpgradedSendStream::new(send_stream) }, recv_stream, buf: Bytes::new(), @@ -297,31 +549,32 @@ where pending.fulfill(upgraded); res.extensions_mut().insert(on_upgrade); - Ok(res) + Poll::Ready(Ok(res)) } else { let res = res.map(|stream| { let ping = ping.for_stream(&stream); IncomingBody::h2(stream, content_length.into(), ping) }); - Ok(res) + Poll::Ready(Ok(res)) } } Err(err) => { ping.ensure_not_timed_out().map_err(|e| (e, None))?; debug!("client response error: {}", err); - Err((crate::Error::new_h2(err), None)) + Poll::Ready(Err((crate::Error::new_h2(err), None::>))) } - }); - self.executor.execute(f.cb.send_when(fut)); + } } } -impl Future for ClientTask +impl Future for ClientTask where - B: Body + Send + 'static, + B: Body + 'static + Unpin, B::Data: Send, - B::Error: Into>, + B::Error: Into>, + E: ExecutorClient + 'static + Send + Sync + Unpin, + T: Read + Write + Unpin, { type Output = crate::Result; diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index c81c0b4665..2002edeb13 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -1,13 +1,13 @@ +use crate::rt::{Read, ReadBufCursor, Write}; use bytes::{Buf, Bytes}; use h2::{Reason, RecvStream, SendStream}; use http::header::{HeaderName, CONNECTION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE}; use http::HeaderMap; use pin_project_lite::pin_project; use std::error::Error as StdError; -use std::io::{self, Cursor, IoSlice}; +use std::io::{Cursor, IoSlice}; use std::mem; use std::task::Context; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{debug, trace, warn}; use crate::body::Body; @@ -85,7 +85,7 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { // body adapters used by both Client and Server pin_project! { - struct PipeToSendStream + pub(crate) struct PipeToSendStream where S: Body, { @@ -271,15 +271,15 @@ where buf: Bytes, } -impl AsyncRead for H2Upgraded +impl Read for H2Upgraded where B: Buf, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - read_buf: &mut ReadBuf<'_>, - ) -> Poll> { + mut read_buf: ReadBufCursor<'_>, + ) -> Poll> { if self.buf.is_empty() { self.buf = loop { match ready!(self.recv_stream.poll_data(cx)) { @@ -295,7 +295,7 @@ where return Poll::Ready(match e.reason() { Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), Some(Reason::STREAM_CLOSED) => { - Err(io::Error::new(io::ErrorKind::BrokenPipe, e)) + Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)) } _ => Err(h2_to_io_error(e)), }) @@ -311,7 +311,7 @@ where } } -impl AsyncWrite for H2Upgraded +impl Write for H2Upgraded where B: Buf, { @@ -319,7 +319,7 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], - ) -> Poll> { + ) -> Poll> { if buf.is_empty() { return Poll::Ready(Ok(0)); } @@ -344,7 +344,7 @@ where Poll::Ready(Err(h2_to_io_error( match ready!(self.send_stream.poll_reset(cx)) { Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) } Ok(reason) => reason.into(), Err(e) => e, @@ -352,14 +352,14 @@ where ))) } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll> { if self.send_stream.write(&[], true).is_ok() { return Poll::Ready(Ok(())); } @@ -368,7 +368,7 @@ where match ready!(self.send_stream.poll_reset(cx)) { Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())), Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) } Ok(reason) => reason.into(), Err(e) => e, @@ -377,11 +377,11 @@ where } } -fn h2_to_io_error(e: h2::Error) -> io::Error { +fn h2_to_io_error(e: h2::Error) -> std::io::Error { if e.is_io() { e.into_io().unwrap() } else { - io::Error::new(io::ErrorKind::Other, e) + std::io::Error::new(std::io::ErrorKind::Other, e) } } @@ -408,7 +408,7 @@ where unsafe { self.as_inner_unchecked().poll_reset(cx) } } - fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), std::io::Error> { let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); unsafe { self.as_inner_unchecked() diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index bf458f428c..0913f314c9 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -3,12 +3,12 @@ use std::marker::Unpin; use std::time::Duration; +use crate::rt::{Read, Write}; use bytes::Bytes; use h2::server::{Connection, Handshake, SendResponse}; use h2::{Reason, RecvStream}; use http::{Method, Request}; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, trace, warn}; use super::{ping, PipeToSendStream, SendBuf}; @@ -89,7 +89,7 @@ where { Handshaking { ping_config: ping::Config, - hs: Handshake>, + hs: Handshake, SendBuf>, }, Serving(Serving), Closed, @@ -100,13 +100,13 @@ where B: Body, { ping: Option<(ping::Recorder, ping::Ponger)>, - conn: Connection>, + conn: Connection, SendBuf>, closing: Option, } impl Server where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, S: HttpService, S::Error: Into>, B: Body + 'static, @@ -132,7 +132,7 @@ where if config.enable_connect_protocol { builder.enable_connect_protocol(); } - let handshake = builder.handshake(io); + let handshake = builder.handshake(crate::common::io::compat(io)); let bdp = if config.adaptive_window { Some(config.initial_stream_window_size) @@ -182,7 +182,7 @@ where impl Future for Server where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, S: HttpService, S::Error: Into>, B: Body + 'static, @@ -228,7 +228,7 @@ where impl Serving where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, B: Body + 'static, { fn poll_server( diff --git a/src/rt/bounds.rs b/src/rt/bounds.rs index 69115ef2ca..36f3683ead 100644 --- a/src/rt/bounds.rs +++ b/src/rt/bounds.rs @@ -6,14 +6,18 @@ #[cfg(all(feature = "server", feature = "http2"))] pub use self::h2::Http2ConnExec; -#[cfg(all(feature = "server", feature = "http2"))] +#[cfg(all(feature = "client", feature = "http2"))] +pub use self::h2_client::ExecutorClient; + +#[cfg(all(feature = "client", feature = "http2"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "http2"))))] -mod h2 { - use crate::{common::exec::Exec, proto::h2::server::H2Stream, rt::Executor}; - use http_body::Body; - use std::future::Future; +mod h2_client { + use std::{error::Error, future::Future}; - /// An executor to spawn http2 connections. + use crate::rt::{Read, Write}; + use crate::{proto::h2::client::H2ClientFuture, rt::Executor}; + + /// An executor to spawn http2 futures for the client. /// /// This trait is implemented for any type that implements [`Executor`] /// trait for any future. @@ -21,28 +25,64 @@ mod h2 { /// This trait is sealed and cannot be implemented for types outside this crate. /// /// [`Executor`]: crate::rt::Executor - pub trait Http2ConnExec: sealed::Sealed<(F, B)> + Clone { + pub trait ExecutorClient: sealed_client::Sealed<(B, T)> + where + B: http_body::Body, + B::Error: Into>, + T: Read + Write + Unpin, + { #[doc(hidden)] - fn execute_h2stream(&mut self, fut: H2Stream); + fn execute_h2_future(&mut self, future: H2ClientFuture); } - impl Http2ConnExec for Exec + impl ExecutorClient for E where - H2Stream: Future + Send + 'static, - B: Body, + E: Executor>, + B: http_body::Body + 'static, + B::Error: Into>, + H2ClientFuture: Future, + T: Read + Write + Unpin, { - fn execute_h2stream(&mut self, fut: H2Stream) { - self.execute(fut) + fn execute_h2_future(&mut self, future: H2ClientFuture) { + self.execute(future) } } - impl sealed::Sealed<(F, B)> for Exec + impl sealed_client::Sealed<(B, T)> for E where - H2Stream: Future + Send + 'static, - B: Body, + E: Executor>, + B: http_body::Body + 'static, + B::Error: Into>, + H2ClientFuture: Future, + T: Read + Write + Unpin, { } + mod sealed_client { + pub trait Sealed {} + } +} + +#[cfg(all(feature = "server", feature = "http2"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "http2"))))] +mod h2 { + use crate::{proto::h2::server::H2Stream, rt::Executor}; + use http_body::Body; + use std::future::Future; + + /// An executor to spawn http2 connections. + /// + /// This trait is implemented for any type that implements [`Executor`] + /// trait for any future. + /// + /// This trait is sealed and cannot be implemented for types outside this crate. + /// + /// [`Executor`]: crate::rt::Executor + pub trait Http2ConnExec: sealed::Sealed<(F, B)> + Clone { + #[doc(hidden)] + fn execute_h2stream(&mut self, fut: H2Stream); + } + #[doc(hidden)] impl Http2ConnExec for E where diff --git a/src/rt/io.rs b/src/rt/io.rs new file mode 100644 index 0000000000..c39e1e098d --- /dev/null +++ b/src/rt/io.rs @@ -0,0 +1,334 @@ +use std::fmt; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; + +// New IO traits? What?! Why, are you bonkers? +// +// I mean, yes, probably. But, here's the goals: +// +// 1. Supports poll-based IO operations. +// 2. Opt-in vectored IO. +// 3. Can use an optional buffer pool. +// 4. Able to add completion-based (uring) IO eventually. +// +// Frankly, the last point is the entire reason we're doing this. We want to +// have forwards-compatibility with an eventually stable io-uring runtime. We +// don't need that to work right away. But it must be possible to add in here +// without breaking hyper 1.0. +// +// While in here, if there's small tweaks to poll_read or poll_write that would +// allow even the "slow" path to be faster, such as if someone didn't remember +// to forward along an `is_completion` call. + +/// Reads bytes from a source. +/// +/// This trait is similar to `std::io::Read`, but supports asynchronous reads. +pub trait Read { + /// Attempts to read bytes into the `buf`. + /// + /// On success, returns `Poll::Ready(Ok(()))` and places data in the + /// unfilled portion of `buf`. If no data was read (`buf.remaining()` is + /// unchanged), it implies that EOF has been reached. + /// + /// If no data is available for reading, the method returns `Poll::Pending` + /// and arranges for the current task (via `cx.waker()`) to receive a + /// notification when the object becomes readable or is closed. + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: ReadBufCursor<'_>, + ) -> Poll>; +} + +/// Write bytes asynchronously. +/// +/// This trait is similar to `std::io::Write`, but for asynchronous writes. +pub trait Write { + /// Attempt to write bytes from `buf` into the destination. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_written)))`. If + /// successful, it must be guaranteed that `n <= buf.len()`. A return value + /// of `0` means that the underlying object is no longer able to accept + /// bytes, or that the provided buffer is empty. + /// + /// If the object is not ready for writing, the method returns + /// `Poll::Pending` and arranges for the current task (via `cx.waker()`) to + /// receive a notification when the object becomes writable or is closed. + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; + + /// Attempts to flush the object. + /// + /// On success, returns `Poll::Ready(Ok(()))`. + /// + /// If flushing cannot immediately complete, this method returns + /// `Poll::Pending` and arranges for the current task (via `cx.waker()`) to + /// receive a notification when the object can make progress. + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + /// Attempts to shut down this writer. + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>; + + /// Returns whether this writer has an efficient `poll_write_vectored` + /// implementation. + /// + /// The default implementation returns `false`. + fn is_write_vectored(&self) -> bool { + false + } + + /// Like `poll_write`, except that it writes from a slice of buffers. + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + let buf = bufs + .iter() + .find(|b| !b.is_empty()) + .map_or(&[][..], |b| &**b); + self.poll_write(cx, buf) + } +} + +/// A wrapper around a byte buffer that is incrementally filled and initialized. +/// +/// This type is a sort of "double cursor". It tracks three regions in the +/// buffer: a region at the beginning of the buffer that has been logically +/// filled with data, a region that has been initialized at some point but not +/// yet logically filled, and a region at the end that may be uninitialized. +/// The filled region is guaranteed to be a subset of the initialized region. +/// +/// In summary, the contents of the buffer can be visualized as: +/// +/// ```not_rust +/// [ capacity ] +/// [ filled | unfilled ] +/// [ initialized | uninitialized ] +/// ``` +/// +/// It is undefined behavior to de-initialize any bytes from the uninitialized +/// region, since it is merely unknown whether this region is uninitialized or +/// not, and if part of it turns out to be initialized, it must stay initialized. +pub struct ReadBuf<'a> { + raw: &'a mut [MaybeUninit], + filled: usize, + init: usize, +} + +/// The cursor part of a [`ReadBuf`]. +/// +/// This is created by calling `ReadBuf::unfilled()`. +#[derive(Debug)] +pub struct ReadBufCursor<'a> { + buf: &'a mut ReadBuf<'a>, +} + +impl<'data> ReadBuf<'data> { + #[inline] + #[cfg(test)] + pub(crate) fn new(raw: &'data mut [u8]) -> Self { + let len = raw.len(); + Self { + // SAFETY: We never de-init the bytes ourselves. + raw: unsafe { &mut *(raw as *mut [u8] as *mut [MaybeUninit]) }, + filled: 0, + init: len, + } + } + + /// Create a new `ReadBuf` with a slice of uninitialized bytes. + #[inline] + pub fn uninit(raw: &'data mut [MaybeUninit]) -> Self { + Self { + raw, + filled: 0, + init: 0, + } + } + + /// Get a slice of the buffer that has been filled in with bytes. + #[inline] + pub fn filled(&self) -> &[u8] { + // SAFETY: We only slice the filled part of the buffer, which is always valid + unsafe { &*(&self.raw[0..self.filled] as *const [MaybeUninit] as *const [u8]) } + } + + /// Get a cursor to the unfilled portion of the buffer. + #[inline] + pub fn unfilled<'cursor>(&'cursor mut self) -> ReadBufCursor<'cursor> { + ReadBufCursor { + // SAFETY: self.buf is never re-assigned, so its safe to narrow + // the lifetime. + buf: unsafe { + std::mem::transmute::<&'cursor mut ReadBuf<'data>, &'cursor mut ReadBuf<'cursor>>( + self, + ) + }, + } + } + + #[inline] + pub(crate) unsafe fn set_init(&mut self, n: usize) { + self.init = self.init.max(n); + } + + #[inline] + pub(crate) unsafe fn set_filled(&mut self, n: usize) { + self.filled = self.filled.max(n); + } + + #[inline] + pub(crate) fn len(&self) -> usize { + self.filled + } + + #[inline] + pub(crate) fn init_len(&self) -> usize { + self.init + } + + #[inline] + fn remaining(&self) -> usize { + self.capacity() - self.filled + } + + #[inline] + fn capacity(&self) -> usize { + self.raw.len() + } +} + +impl<'data> fmt::Debug for ReadBuf<'data> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadBuf") + .field("filled", &self.filled) + .field("init", &self.init) + .field("capacity", &self.capacity()) + .finish() + } +} + +impl<'data> ReadBufCursor<'data> { + /// Access the unfilled part of the buffer. + /// + /// # Safety + /// + /// The caller must not uninitialize any bytes that may have been + /// initialized before. + #[inline] + pub unsafe fn as_mut(&mut self) -> &mut [MaybeUninit] { + &mut self.buf.raw[self.buf.filled..] + } + + /// Advance the `filled` cursor by `n` bytes. + /// + /// # Safety + /// + /// The caller must take care that `n` more bytes have been initialized. + #[inline] + pub unsafe fn advance(&mut self, n: usize) { + self.buf.filled = self.buf.filled.checked_add(n).expect("overflow"); + self.buf.init = self.buf.filled.max(self.buf.init); + } + + #[inline] + pub(crate) fn remaining(&self) -> usize { + self.buf.remaining() + } + + #[inline] + pub(crate) fn put_slice(&mut self, buf: &[u8]) { + assert!( + self.buf.remaining() >= buf.len(), + "buf.len() must fit in remaining()" + ); + + let amt = buf.len(); + // Cannot overflow, asserted above + let end = self.buf.filled + amt; + + // Safety: the length is asserted above + unsafe { + self.buf.raw[self.buf.filled..end] + .as_mut_ptr() + .cast::() + .copy_from_nonoverlapping(buf.as_ptr(), amt); + } + + if self.buf.init < end { + self.buf.init = end; + } + self.buf.filled = end; + } +} + +macro_rules! deref_async_read { + () => { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: ReadBufCursor<'_>, + ) -> Poll> { + Pin::new(&mut **self).poll_read(cx, buf) + } + }; +} + +impl Read for Box { + deref_async_read!(); +} + +impl Read for &mut T { + deref_async_read!(); +} + +macro_rules! deref_async_write { + () => { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut **self).poll_shutdown(cx) + } + }; +} + +impl Write for Box { + deref_async_write!(); +} + +impl Write for &mut T { + deref_async_write!(); +} diff --git a/src/rt/mod.rs b/src/rt/mod.rs index 803d010f40..de67c3fc89 100644 --- a/src/rt/mod.rs +++ b/src/rt/mod.rs @@ -1,17 +1,18 @@ //! Runtime components //! -//! By default, hyper includes the [tokio](https://tokio.rs) runtime. +//! The traits and types within this module are used to allow plugging in +//! runtime types. These include: //! -//! If the `runtime` feature is disabled, the types in this module can be used -//! to plug in other runtimes. +//! - Executors +//! - Timers +//! - IO transports pub mod bounds; +mod io; +mod timer; -use std::{ - future::Future, - pin::Pin, - time::{Duration, Instant}, -}; +pub use self::io::{Read, ReadBuf, ReadBufCursor, Write}; +pub use self::timer::{Sleep, Timer}; /// An executor of futures. /// @@ -39,20 +40,3 @@ pub trait Executor { /// Place the future into the executor to be run. fn execute(&self, fut: Fut); } - -/// A timer which provides timer-like functions. -pub trait Timer { - /// Return a future that resolves in `duration` time. - fn sleep(&self, duration: Duration) -> Pin>; - - /// Return a future that resolves at `deadline`. - fn sleep_until(&self, deadline: Instant) -> Pin>; - - /// Reset a future to resolve at `new_deadline` instead. - fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { - *sleep = self.sleep_until(new_deadline); - } -} - -/// A future returned by a `Timer`. -pub trait Sleep: Send + Sync + Future {} diff --git a/src/rt/timer.rs b/src/rt/timer.rs new file mode 100644 index 0000000000..6ecb964373 --- /dev/null +++ b/src/rt/timer.rs @@ -0,0 +1,127 @@ +//! Provides a timer trait with timer-like functions +//! +//! Example using tokio timer: +//! ```rust +//! use std::{ +//! pin::Pin, +//! task::{Context, Poll}, +//! time::{Duration, Instant}, +//! }; +//! +//! use futures_util::Future; +//! use pin_project_lite::pin_project; +//! use hyper::rt::{Timer, Sleep}; +//! +//! #[derive(Clone, Debug)] +//! pub struct TokioTimer; +//! +//! impl Timer for TokioTimer { +//! fn sleep(&self, duration: Duration) -> Pin> { +//! Box::pin(TokioSleep { +//! inner: tokio::time::sleep(duration), +//! }) +//! } +//! +//! fn sleep_until(&self, deadline: Instant) -> Pin> { +//! Box::pin(TokioSleep { +//! inner: tokio::time::sleep_until(deadline.into()), +//! }) +//! } +//! +//! fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { +//! if let Some(sleep) = sleep.as_mut().downcast_mut_pin::() { +//! sleep.reset(new_deadline.into()) +//! } +//! } +//! } +//! +//! pin_project! { +//! pub(crate) struct TokioSleep { +//! #[pin] +//! pub(crate) inner: tokio::time::Sleep, +//! } +//! } +//! +//! impl Future for TokioSleep { +//! type Output = (); +//! +//! fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { +//! self.project().inner.poll(cx) +//! } +//! } +//! +//! impl Sleep for TokioSleep {} +//! +//! impl TokioSleep { +//! pub fn reset(self: Pin<&mut Self>, deadline: Instant) { +//! self.project().inner.as_mut().reset(deadline.into()); +//! } +//! } +//! ```` + +use std::{ + any::TypeId, + future::Future, + pin::Pin, + time::{Duration, Instant}, +}; + +/// A timer which provides timer-like functions. +pub trait Timer { + /// Return a future that resolves in `duration` time. + fn sleep(&self, duration: Duration) -> Pin>; + + /// Return a future that resolves at `deadline`. + fn sleep_until(&self, deadline: Instant) -> Pin>; + + /// Reset a future to resolve at `new_deadline` instead. + fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { + *sleep = self.sleep_until(new_deadline); + } +} + +/// A future returned by a `Timer`. +pub trait Sleep: Send + Sync + Future { + #[doc(hidden)] + /// This method is private and can not be implemented by downstream crate + fn __type_id(&self, _: private::Sealed) -> TypeId + where + Self: 'static, + { + TypeId::of::() + } +} + +impl dyn Sleep { + //! This is a re-implementation of downcast methods from std::any::Any + + /// Check whether the type is the same as `T` + pub fn is(&self) -> bool + where + T: Sleep + 'static, + { + self.__type_id(private::Sealed {}) == TypeId::of::() + } + + /// Downcast a pinned &mut Sleep object to its original type + pub fn downcast_mut_pin(self: Pin<&mut Self>) -> Option> + where + T: Sleep + 'static, + { + if self.is::() { + unsafe { + let inner = Pin::into_inner_unchecked(self); + Some(Pin::new_unchecked( + &mut *(&mut *inner as *mut dyn Sleep as *mut T), + )) + } + } else { + None + } + } +} + +mod private { + #![allow(missing_debug_implementations)] + pub struct Sealed {} +} diff --git a/src/server/conn/http1.rs b/src/server/conn/http1.rs index 530082e966..09770cd3cb 100644 --- a/src/server/conn/http1.rs +++ b/src/server/conn/http1.rs @@ -5,8 +5,8 @@ use std::fmt; use std::sync::Arc; use std::time::Duration; +use crate::rt::{Read, Write}; use bytes::Bytes; -use tokio::io::{AsyncRead, AsyncWrite}; use crate::body::{Body, Incoming as IncomingBody}; use crate::common::{task, Future, Pin, Poll, Unpin}; @@ -85,7 +85,7 @@ impl Connection where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, B: Body + 'static, B::Error: Into>, { @@ -172,7 +172,7 @@ impl Future for Connection where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + 'static, + I: Read + Write + Unpin + 'static, B: Body + 'static, B::Error: Into>, { @@ -333,10 +333,10 @@ impl Builder { /// # use hyper::{body::Incoming, Request, Response}; /// # use hyper::service::Service; /// # use hyper::server::conn::http1::Builder; - /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use hyper::rt::{Read, Write}; /// # async fn run(some_io: I, some_service: S) /// # where - /// # I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + /// # I: Read + Write + Unpin + Send + 'static, /// # S: Service, Response=hyper::Response> + Send + 'static, /// # S::Error: Into>, /// # S::Future: Send, @@ -356,7 +356,7 @@ impl Builder { S::Error: Into>, S::ResBody: 'static, ::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, { let mut conn = proto::Conn::new(io); conn.set_timer(self.timer.clone()); @@ -413,7 +413,7 @@ mod upgrades { where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, B: Body + 'static, B::Error: Into>, { @@ -430,7 +430,7 @@ mod upgrades { where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + I: Read + Write + Unpin + Send + 'static, B: Body + 'static, B::Error: Into>, { diff --git a/src/server/conn/http2.rs b/src/server/conn/http2.rs index e1345f3b6b..f6f09f45f1 100644 --- a/src/server/conn/http2.rs +++ b/src/server/conn/http2.rs @@ -5,8 +5,8 @@ use std::fmt; use std::sync::Arc; use std::time::Duration; +use crate::rt::{Read, Write}; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite}; use crate::body::{Body, Incoming as IncomingBody}; use crate::common::{task, Future, Pin, Poll, Unpin}; @@ -51,7 +51,7 @@ impl Connection where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, B: Body + 'static, B::Error: Into>, E: Http2ConnExec, @@ -75,7 +75,7 @@ impl Future for Connection where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + 'static, + I: Read + Write + Unpin + 'static, B: Body + 'static, B::Error: Into>, E: Http2ConnExec, @@ -118,7 +118,7 @@ impl Builder { /// /// If not set, hyper will use a default. /// - /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + /// [spec]: https://httpwg.org/specs/rfc9113.html#SETTINGS_INITIAL_WINDOW_SIZE pub fn initial_stream_window_size(&mut self, sz: impl Into>) -> &mut Self { if let Some(sz) = sz.into() { self.h2_builder.adaptive_window = false; @@ -173,7 +173,7 @@ impl Builder { /// /// Default is no limit (`std::u32::MAX`). Passing `None` will do nothing. /// - /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS + /// [spec]: https://httpwg.org/specs/rfc9113.html#SETTINGS_MAX_CONCURRENT_STREAMS pub fn max_concurrent_streams(&mut self, max: impl Into>) -> &mut Self { self.h2_builder.max_concurrent_streams = max.into(); self @@ -255,7 +255,7 @@ impl Builder { S::Error: Into>, Bd: Body + 'static, Bd::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, E: Http2ConnExec, { let proto = proto::h2::Server::new( diff --git a/src/server/conn/mod.rs b/src/server/conn/mod.rs index f2abae22aa..b7dea1b8c6 100644 --- a/src/server/conn/mod.rs +++ b/src/server/conn/mod.rs @@ -7,43 +7,6 @@ //! //! This module is split by HTTP version. Both work similarly, but do have //! specific options on each builder. -//! -//! ## Example -//! -//! A simple example that prepares an HTTP/1 connection over a Tokio TCP stream. -//! -//! ```no_run -//! # #[cfg(feature = "http1")] -//! # mod rt { -//! use http::{Request, Response, StatusCode}; -//! use http_body_util::Full; -//! use hyper::{server::conn::http1, service::service_fn, body, body::Bytes}; -//! use std::{net::SocketAddr, convert::Infallible}; -//! use tokio::net::TcpListener; -//! -//! #[tokio::main] -//! async fn main() -> Result<(), Box> { -//! let addr: SocketAddr = ([127, 0, 0, 1], 8080).into(); -//! -//! let mut tcp_listener = TcpListener::bind(addr).await?; -//! loop { -//! let (tcp_stream, _) = tcp_listener.accept().await?; -//! tokio::task::spawn(async move { -//! if let Err(http_err) = http1::Builder::new() -//! .keep_alive(true) -//! .serve_connection(tcp_stream, service_fn(hello)) -//! .await { -//! eprintln!("Error while serving HTTP connection: {}", http_err); -//! } -//! }); -//! } -//! } -//! -//! async fn hello(_req: Request) -> Result>, Infallible> { -//! Ok(Response::new(Full::new(Bytes::from("Hello World!")))) -//! } -//! # } -//! ``` #[cfg(feature = "http1")] pub mod http1; diff --git a/src/upgrade.rs b/src/upgrade.rs index 1c7b5b01cd..231578f913 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -45,8 +45,8 @@ use std::fmt; use std::io; use std::marker::Unpin; +use crate::rt::{Read, ReadBufCursor, Write}; use bytes::Bytes; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::oneshot; #[cfg(any(feature = "http1", feature = "http2"))] use tracing::trace; @@ -122,7 +122,7 @@ impl Upgraded { #[cfg(any(feature = "http1", feature = "http2", test))] pub(super) fn new(io: T, read_buf: Bytes) -> Self where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + Send + 'static, { Upgraded { io: Rewind::new_buffered(Box::new(io), read_buf), @@ -133,7 +133,7 @@ impl Upgraded { /// /// On success, returns the downcasted parts. On error, returns the /// `Upgraded` back. - pub fn downcast(self) -> Result, Self> { + pub fn downcast(self) -> Result, Self> { let (io, buf) = self.io.into_inner(); match io.__hyper_downcast() { Ok(t) => Ok(Parts { @@ -148,17 +148,17 @@ impl Upgraded { } } -impl AsyncRead for Upgraded { +impl Read for Upgraded { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.io).poll_read(cx, buf) } } -impl AsyncWrite for Upgraded { +impl Write for Upgraded { fn poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, @@ -265,13 +265,13 @@ impl StdError for UpgradeExpected {} // ===== impl Io ===== -pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { +pub(super) trait Io: Read + Write + Unpin + 'static { fn __hyper_type_id(&self) -> TypeId { TypeId::of::() } } -impl Io for T {} +impl Io for T {} impl dyn Io + Send { fn __hyper_is(&self) -> bool { @@ -340,7 +340,9 @@ mod tests { fn upgraded_downcast() { let upgraded = Upgraded::new(Mock, Bytes::new()); - let upgraded = upgraded.downcast::>>().unwrap_err(); + let upgraded = upgraded + .downcast::>>>() + .unwrap_err(); upgraded.downcast::().unwrap(); } @@ -348,17 +350,17 @@ mod tests { // TODO: replace with tokio_test::io when it can test write_buf struct Mock; - impl AsyncRead for Mock { + impl Read for Mock { fn poll_read( self: Pin<&mut Self>, _cx: &mut task::Context<'_>, - _buf: &mut ReadBuf<'_>, + _buf: ReadBufCursor<'_>, ) -> Poll> { unreachable!("Mock::poll_read") } } - impl AsyncWrite for Mock { + impl Write for Mock { fn poll_write( self: Pin<&mut Self>, _: &mut task::Context<'_>, diff --git a/tests/client.rs b/tests/client.rs index 842282c5bb..ef80596c01 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -22,6 +22,7 @@ use hyper::{Method, Request, StatusCode, Uri, Version}; use bytes::Bytes; use futures_channel::oneshot; use futures_util::future::{self, FutureExt, TryFuture, TryFutureExt}; +use support::TokioIo; use tokio::net::TcpStream; mod support; @@ -36,8 +37,8 @@ where b.collect().await.map(|c| c.to_bytes()) } -fn tcp_connect(addr: &SocketAddr) -> impl Future> { - TcpStream::connect(*addr) +async fn tcp_connect(addr: &SocketAddr) -> std::io::Result> { + TcpStream::connect(*addr).await.map(TokioIo::new) } struct HttpInfo { @@ -312,7 +313,7 @@ macro_rules! test { req.headers_mut().append("Host", HeaderValue::from_str(&host).unwrap()); } - let (mut sender, conn) = builder.handshake(stream).await?; + let (mut sender, conn) = builder.handshake(TokioIo::new(stream)).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -1339,7 +1340,7 @@ mod conn { use futures_util::future::{self, poll_fn, FutureExt, TryFutureExt}; use http_body_util::{BodyExt, Empty, StreamBody}; use hyper::rt::Timer; - use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, ReadBuf}; + use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; use tokio::net::{TcpListener as TkTcpListener, TcpStream}; use hyper::body::{Body, Frame}; @@ -1349,7 +1350,7 @@ mod conn { use super::{concat, s, support, tcp_connect, FutureHyperExt}; - use support::{TokioExecutor, TokioTimer}; + use support::{TokioExecutor, TokioIo, TokioTimer}; fn setup_logger() { let _ = pretty_env_logger::try_init(); @@ -1773,7 +1774,7 @@ mod conn { } let parts = conn.into_parts(); - let mut io = parts.io; + let io = parts.io; let buf = parts.read_buf; assert_eq!(buf, b"foobar=ready"[..]); @@ -1785,6 +1786,7 @@ mod conn { })) .unwrap_err(); + let mut io = io.tcp.inner(); let mut vec = vec![]; rt.block_on(io.write_all(b"foo=bar")).unwrap(); rt.block_on(io.read_to_end(&mut vec)).unwrap(); @@ -1861,7 +1863,7 @@ mod conn { } let parts = conn.into_parts(); - let mut io = parts.io; + let io = parts.io; let buf = parts.read_buf; assert_eq!(buf, b"foobar=ready"[..]); @@ -1874,6 +1876,7 @@ mod conn { })) .unwrap_err(); + let mut io = io.tcp.inner(); let mut vec = vec![]; rt.block_on(io.write_all(b"foo=bar")).unwrap(); rt.block_on(io.read_to_end(&mut vec)).unwrap(); @@ -1895,6 +1898,7 @@ mod conn { tokio::select! { res = listener.accept() => { let (stream, _) = res.unwrap(); + let stream = TokioIo::new(stream); let service = service_fn(|_:Request| future::ok::<_, hyper::Error>(Response::new(Empty::::new()))); @@ -2077,7 +2081,7 @@ mod conn { // Spawn an HTTP2 server that reads the whole body and responds tokio::spawn(async move { - let sock = listener.accept().await.unwrap().0; + let sock = TokioIo::new(listener.accept().await.unwrap().0); hyper::server::conn::http2::Builder::new(TokioExecutor) .timer(TokioTimer) .serve_connection( @@ -2166,7 +2170,7 @@ mod conn { let res = client.send_request(req).await.expect("send_request"); assert_eq!(res.status(), StatusCode::OK); - let mut upgraded = hyper::upgrade::on(res).await.unwrap(); + let mut upgraded = TokioIo::new(hyper::upgrade::on(res).await.unwrap()); let mut vec = vec![]; upgraded.read_to_end(&mut vec).await.unwrap(); @@ -2264,7 +2268,7 @@ mod conn { ); } - async fn drain_til_eof(mut sock: T) -> io::Result<()> { + async fn drain_til_eof(mut sock: T) -> io::Result<()> { let mut buf = [0u8; 1024]; loop { let n = sock.read(&mut buf).await?; @@ -2276,11 +2280,11 @@ mod conn { } struct DebugStream { - tcp: TcpStream, + tcp: TokioIo, shutdown_called: bool, } - impl AsyncWrite for DebugStream { + impl hyper::rt::Write for DebugStream { fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -2305,11 +2309,11 @@ mod conn { } } - impl AsyncRead for DebugStream { + impl hyper::rt::Read for DebugStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: hyper::rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.tcp).poll_read(cx, buf) } diff --git a/tests/server.rs b/tests/server.rs index 7a1a5dd430..98ded22d73 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -22,8 +22,8 @@ use h2::{RecvStream, SendStream}; use http::header::{HeaderName, HeaderValue}; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; use hyper::rt::Timer; -use support::{TokioExecutor, TokioTimer}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use hyper::rt::{Read as AsyncRead, Write as AsyncWrite}; +use support::{TokioExecutor, TokioIo, TokioTimer}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener as TkTcpListener, TcpListener, TcpStream as TkTcpStream}; @@ -31,6 +31,7 @@ use hyper::body::{Body, Incoming as IncomingBody}; use hyper::server::conn::{http1, http2}; use hyper::service::{service_fn, Service}; use hyper::{Method, Request, Response, StatusCode, Uri, Version}; +use tokio::pin; mod support; @@ -974,6 +975,7 @@ async fn expect_continue_waits_for_body_poll() { }); let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection( @@ -1139,14 +1141,21 @@ async fn disable_keep_alive_mid_request() { let child = thread::spawn(move || { let mut req = connect(&addr); req.write_all(b"GET / HTTP/1.1\r\n").unwrap(); + thread::sleep(Duration::from_millis(10)); tx1.send(()).unwrap(); rx2.recv().unwrap(); req.write_all(b"Host: localhost\r\n\r\n").unwrap(); let mut buf = vec![]; req.read_to_end(&mut buf).unwrap(); + assert!( + buf.starts_with(b"HTTP/1.1 200 OK\r\n"), + "should receive OK response, but buf: {:?}", + buf, + ); }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let srv = http1::Builder::new().serve_connection(socket, HelloWorld); future::try_select(srv, rx1) .then(|r| match r { @@ -1194,7 +1203,7 @@ async fn disable_keep_alive_post_request() { let dropped2 = dropped.clone(); let (socket, _) = listener.accept().await.unwrap(); let transport = DebugStream { - stream: socket, + stream: TokioIo::new(socket), _debug: dropped2, }; let server = http1::Builder::new().serve_connection(transport, HelloWorld); @@ -1222,6 +1231,7 @@ async fn empty_parse_eof_does_not_return_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -1238,6 +1248,7 @@ async fn nonempty_parse_eof_returns_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -1261,6 +1272,7 @@ async fn http1_allow_half_close() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .half_close(true) .serve_connection( @@ -1288,6 +1300,7 @@ async fn disconnect_after_reading_request_before_responding() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .half_close(false) .serve_connection( @@ -1319,6 +1332,7 @@ async fn returning_1xx_response_is_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection( socket, @@ -1383,6 +1397,7 @@ async fn header_read_timeout_slow_writes() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new() .timer(TokioTimer) .header_read_timeout(Duration::from_secs(5)) @@ -1458,6 +1473,7 @@ async fn header_read_timeout_slow_writes_multiple_requests() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new() .timer(TokioTimer) .header_read_timeout(Duration::from_secs(5)) @@ -1504,6 +1520,7 @@ async fn upgrades() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new().serve_connection( socket, service_fn(|_| { @@ -1522,7 +1539,7 @@ async fn upgrades() { // wait so that we don't write until other side saw 101 response rx.await.unwrap(); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1557,6 +1574,7 @@ async fn http_connect() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new().serve_connection( socket, service_fn(|_| { @@ -1574,7 +1592,7 @@ async fn http_connect() { // wait so that we don't write until other side saw 101 response rx.await.unwrap(); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1627,6 +1645,7 @@ async fn upgrades_new() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, svc) .with_upgrades() @@ -1639,10 +1658,10 @@ async fn upgrades_new() { read_101_rx.await.unwrap(); let upgraded = on_upgrade.await.expect("on_upgrade"); - let parts = upgraded.downcast::().unwrap(); + let parts = upgraded.downcast::>().unwrap(); assert_eq!(parts.read_buf, "eagerly optimistic"); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1661,6 +1680,7 @@ async fn upgrades_ignored() { loop { let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); tokio::task::spawn(async move { http1::Builder::new() .serve_connection(socket, svc) @@ -1731,6 +1751,7 @@ async fn http_connect_new() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, svc) .with_upgrades() @@ -1743,10 +1764,10 @@ async fn http_connect_new() { read_200_rx.await.unwrap(); let upgraded = on_upgrade.await.expect("on_upgrade"); - let parts = upgraded.downcast::().unwrap(); + let parts = upgraded.downcast::>().unwrap(); assert_eq!(parts.read_buf, "eagerly optimistic"); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1792,7 +1813,7 @@ async fn h2_connect() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -1811,6 +1832,7 @@ async fn h2_connect() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -1884,7 +1906,7 @@ async fn h2_connect_multiplex() { assert!(upgrade_res.expect_err("upgrade cancelled").is_canceled()); return; } - let mut upgraded = upgrade_res.expect("upgrade successful"); + let mut upgraded = TokioIo::new(upgrade_res.expect("upgrade successful")); upgraded.write_all(b"Bread?").await.unwrap(); @@ -1920,6 +1942,7 @@ async fn h2_connect_multiplex() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -1971,7 +1994,7 @@ async fn h2_connect_large_body() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -1992,6 +2015,7 @@ async fn h2_connect_large_body() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -2042,7 +2066,7 @@ async fn h2_connect_empty_frames() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -2061,6 +2085,7 @@ async fn h2_connect_empty_frames() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -2083,6 +2108,7 @@ async fn parse_errors_send_4xx_response() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -2105,6 +2131,7 @@ async fn illegal_request_length_returns_400_response() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -2145,6 +2172,7 @@ async fn max_buf_size() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .max_buf_size(MAX) .serve_connection(socket, HelloWorld) @@ -2152,6 +2180,32 @@ async fn max_buf_size() { .expect_err("should TooLarge error"); } +#[cfg(feature = "http1")] +#[tokio::test] +async fn graceful_shutdown_before_first_request_no_block() { + let (listener, addr) = setup_tcp_listener(); + + tokio::spawn(async move { + let socket = listener.accept().await.unwrap().0; + let socket = TokioIo::new(socket); + + let future = http1::Builder::new().serve_connection(socket, HelloWorld); + pin!(future); + future.as_mut().graceful_shutdown(); + + future.await.unwrap(); + }); + + let mut stream = TkTcpStream::connect(addr).await.unwrap(); + + let mut buf = vec![]; + + tokio::time::timeout(Duration::from_secs(5), stream.read_to_end(&mut buf)) + .await + .expect("timed out waiting for graceful shutdown") + .expect("error receiving response"); +} + #[test] fn streaming_body() { use futures_util::StreamExt; @@ -2375,6 +2429,7 @@ async fn http2_keep_alive_detects_unresponsive_client() { }); let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); let err = http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2393,6 +2448,7 @@ async fn http2_keep_alive_with_responsive_client() { tokio::spawn(async move { let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2403,7 +2459,7 @@ async fn http2_keep_alive_with_responsive_client() { .expect("serve_connection"); }); - let tcp = connect_async(addr).await; + let tcp = TokioIo::new(connect_async(addr).await); let (mut client, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) .handshake(tcp) .await @@ -2456,6 +2512,7 @@ async fn http2_keep_alive_count_server_pings() { tokio::spawn(async move { let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2839,6 +2896,7 @@ impl ServeOptions { tokio::select! { res = listener.accept() => { let (stream, _) = res.unwrap(); + let stream = TokioIo::new(stream); tokio::task::spawn(async move { let msg_tx = msg_tx.clone(); @@ -2890,7 +2948,7 @@ fn has_header(msg: &str, name: &str) -> bool { msg[..n].contains(name) } -fn tcp_bind(addr: &SocketAddr) -> ::tokio::io::Result { +fn tcp_bind(addr: &SocketAddr) -> std::io::Result { let std_listener = StdTcpListener::bind(addr).unwrap(); std_listener.set_nonblocking(true).unwrap(); TcpListener::from_std(std_listener) @@ -2969,7 +3027,7 @@ impl AsyncRead for DebugStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: hyper::rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.stream).poll_read(cx, buf) } @@ -3026,9 +3084,11 @@ impl TestClient { let host = req.uri().host().expect("uri has no host"); let port = req.uri().port_u16().expect("uri has no port"); - let stream = TkTcpStream::connect(format!("{}:{}", host, port)) - .await - .unwrap(); + let stream = TokioIo::new( + TkTcpStream::connect(format!("{}:{}", host, port)) + .await + .unwrap(), + ); if self.http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) diff --git a/tests/support/mod.rs b/tests/support/mod.rs index e7e1e8c6bd..c46eff89ea 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -21,7 +21,7 @@ pub use hyper::{HeaderMap, StatusCode}; pub use std::net::SocketAddr; mod tokiort; -pub use tokiort::{TokioExecutor, TokioTimer}; +pub use tokiort::{TokioExecutor, TokioIo, TokioTimer}; #[allow(unused_macros)] macro_rules! t { @@ -357,6 +357,7 @@ async fn async_test(cfg: __TestConfig) { loop { let (stream, _) = listener.accept().await.expect("server error"); + let io = TokioIo::new(stream); // Move a clone into the service_fn let serve_handles = serve_handles.clone(); @@ -386,12 +387,12 @@ async fn async_test(cfg: __TestConfig) { tokio::task::spawn(async move { if http2_only { server::conn::http2::Builder::new(TokioExecutor) - .serve_connection(stream, service) + .serve_connection(io, service) .await .expect("server error"); } else { server::conn::http1::Builder::new() - .serve_connection(stream, service) + .serve_connection(io, service) .await .expect("server error"); } @@ -425,10 +426,11 @@ async fn async_test(cfg: __TestConfig) { async move { let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); let res = if http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -440,7 +442,7 @@ async fn async_test(cfg: __TestConfig) { sender.send_request(req).await.unwrap() } else { let (mut sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -508,6 +510,7 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); let service = service_fn(move |mut req| { async move { @@ -523,11 +526,12 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) let stream = TcpStream::connect(format!("{}:{}", uri, port)) .await .unwrap(); + let io = TokioIo::new(stream); let resp = if http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -540,7 +544,7 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) sender.send_request(req).await? } else { let builder = hyper::client::conn::http1::Builder::new(); - let (mut sender, conn) = builder.handshake(stream).await.unwrap(); + let (mut sender, conn) = builder.handshake(io).await.unwrap(); tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -569,12 +573,12 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) if http2_only { server::conn::http2::Builder::new(TokioExecutor) - .serve_connection(stream, service) + .serve_connection(io, service) .await .unwrap(); } else { server::conn::http1::Builder::new() - .serve_connection(stream, service) + .serve_connection(io, service) .await .unwrap(); }