Skip to content

Commit

Permalink
feat(http1) Add support for Trailer Fields
Browse files Browse the repository at this point in the history
Closes #2719
  • Loading branch information
hjr3 committed Oct 26, 2023
1 parent 04f0981 commit 60eaca9
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 21 deletions.
25 changes: 25 additions & 0 deletions src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,12 @@ where
self.state.reading = Reading::Body(Decoder::new(msg.decode));
}

if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) {
if te_value.eq_ignore_ascii_case("trailers") {
wants = wants.add(Wants::TRAILERS);
}
}

Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
}

Expand Down Expand Up @@ -640,6 +646,25 @@ where
self.state.writing = state;
}

pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) {
debug_assert!(self.can_write_body() && self.can_buffer_body());

match self.state.writing {
Writing::Body(ref encoder) => {
if let Some(enc_buf) = encoder.encode_trailers(trailers) {
self.io.buffer(enc_buf);

self.state.writing = if encoder.is_last() || encoder.is_close_delimited() {
Writing::Closed
} else {
Writing::KeepAlive
};
}
}
_ => unreachable!("write_trailers invalid state: {:?}", self.state.writing),
}
}

pub(crate) fn write_body_and_end(&mut self, chunk: B) {
debug_assert!(self.can_write_body() && self.can_buffer_body());
// empty chunks should be discarded at Dispatcher level
Expand Down
42 changes: 24 additions & 18 deletions src/proto/h1/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,27 +351,33 @@ where
*clear_body = true;
crate::Error::new_user_body(e)
})?;
let chunk = if let Ok(data) = frame.into_data() {
data
} else {
trace!("discarding non-data frame");
continue;
};
let eos = body.is_end_stream();
if eos {
*clear_body = true;
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
self.conn.end_body()?;

if frame.is_data() {
let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
let eos = body.is_end_stream();
if eos {
*clear_body = true;
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
self.conn.end_body()?;
} else {
self.conn.write_body_and_end(chunk);
}
} else {
self.conn.write_body_and_end(chunk);
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
continue;
}
self.conn.write_body(chunk);
}
} else if frame.is_trailers() {
*clear_body = true;
self.conn.write_trailers(
frame.into_trailers().unwrap_or_else(|_| unreachable!()),
);
} else {
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
continue;
}
self.conn.write_body(chunk);
trace!("discarding unknown frame");
continue;
}
} else {
*clear_body = true;
Expand Down
38 changes: 37 additions & 1 deletion src/proto/h1/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::fmt;
use std::io::IoSlice;

use bytes::buf::{Chain, Take};
use bytes::Buf;
use bytes::{Buf, Bytes};
use http::HeaderMap;

use super::io::WriteBuf;

Expand Down Expand Up @@ -45,6 +46,7 @@ enum BufKind<B> {
Limited(Take<B>),
Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
ChunkedEnd(StaticBuf),
Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
}

impl Encoder {
Expand Down Expand Up @@ -136,6 +138,20 @@ impl Encoder {
EncodedBuf { kind }
}

pub(crate) fn encode_trailers<B>(&self, trailers: HeaderMap) -> Option<EncodedBuf<B>> {
match self.kind {
Kind::Chunked => {
let mut buf = Vec::new();
write_headers(&trailers, &mut buf);

Some(EncodedBuf {
kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
})
}
_ => None, // silently discard trailers
}
}

pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
where
B: Buf,
Expand Down Expand Up @@ -181,6 +197,22 @@ impl Encoder {
}
}

// FIXME: dry up
fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) {
for (name, value) in headers {
extend(dst, name.as_str().as_bytes());
extend(dst, b": ");
extend(dst, value.as_bytes());
extend(dst, b"\r\n");
}
}

#[inline]
fn extend(dst: &mut Vec<u8>, data: &[u8]) {
dst.extend_from_slice(data);
}
// end FIXME: dry up

impl<B> Buf for EncodedBuf<B>
where
B: Buf,
Expand All @@ -192,6 +224,7 @@ where
BufKind::Limited(ref b) => b.remaining(),
BufKind::Chunked(ref b) => b.remaining(),
BufKind::ChunkedEnd(ref b) => b.remaining(),
BufKind::Trailers(ref b) => b.remaining(),
}
}

Expand All @@ -202,6 +235,7 @@ where
BufKind::Limited(ref b) => b.chunk(),
BufKind::Chunked(ref b) => b.chunk(),
BufKind::ChunkedEnd(ref b) => b.chunk(),
BufKind::Trailers(ref b) => b.chunk(),
}
}

Expand All @@ -212,6 +246,7 @@ where
BufKind::Limited(ref mut b) => b.advance(cnt),
BufKind::Chunked(ref mut b) => b.advance(cnt),
BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
BufKind::Trailers(ref mut b) => b.advance(cnt),
}
}

Expand All @@ -222,6 +257,7 @@ where
BufKind::Limited(ref b) => b.chunks_vectored(dst),
BufKind::Chunked(ref b) => b.chunks_vectored(dst),
BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
BufKind::Trailers(ref b) => b.chunks_vectored(dst),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/proto/h1/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ impl Wants {
const EMPTY: Wants = Wants(0b00);
const EXPECT: Wants = Wants(0b01);
const UPGRADE: Wants = Wants(0b10);
const TRAILERS: Wants = Wants(0b100);

#[must_use]
fn add(self, other: Wants) -> Wants {
Expand Down
46 changes: 45 additions & 1 deletion tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::convert::Infallible;
use std::fmt;
use std::future::Future;
use std::io::{Read, Write};
use std::iter::FromIterator;
use std::net::{SocketAddr, TcpListener};
use std::pin::Pin;
use std::thread;
Expand All @@ -13,7 +14,7 @@ use std::time::Duration;
use http::uri::PathAndQuery;
use http_body_util::{BodyExt, StreamBody};
use hyper::body::Frame;
use hyper::header::HeaderValue;
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use hyper::{Method, Request, StatusCode, Uri, Version};

use bytes::Bytes;
Expand Down Expand Up @@ -408,6 +409,15 @@ macro_rules! __client_req_prop {
Frame::data,
)));
}};

($req_builder:ident, $body:ident, $addr:ident, body_stream_with_trailers: $body_e:expr) => {{
use support::trailers::StreamBodyWithTrailers;
let (body, trailers) = $body_e;
$body = BodyExt::boxed(StreamBodyWithTrailers::with_trailers(
futures_util::TryStreamExt::map_ok(body, Frame::data),
trailers,
));
}};
}

macro_rules! __client_req_header {
Expand Down Expand Up @@ -631,6 +641,40 @@ test! {
body: &b"hello"[..],
}

test! {
name: client_post_req_body_chunked_with_trailer,

server:
expected: "\
POST / HTTP/1.1\r\n\
host: {addr}\r\n\
transfer-encoding: chunked\r\n\
\r\n\
5\r\n\
hello\r\n\
0\r\n\
chunky-trailer: header data\r\n\
\r\n\
",
reply: REPLY_OK,

client:
request: {
method: POST,
url: "http://{addr}/",
body_stream_with_trailers: (
(futures_util::stream::once(async { Ok::<_, Infallible>(Bytes::from("hello"))})),
HeaderMap::from_iter(vec![(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data")
)].into_iter())),
},
response:
status: OK,
headers: {},
body: None,
}

test! {
name: client_get_req_body_sized,

Expand Down
57 changes: 56 additions & 1 deletion tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use futures_channel::oneshot;
use futures_util::future::{self, Either, FutureExt};
use h2::client::SendRequest;
use h2::{RecvStream, SendStream};
use http::header::{HeaderName, HeaderValue};
use http::header::{HeaderMap, HeaderName, HeaderValue};
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
use hyper::rt::Timer;
use hyper::rt::{Read as AsyncRead, Write as AsyncWrite};
Expand Down Expand Up @@ -2595,6 +2595,48 @@ async fn http2_keep_alive_count_server_pings() {
.expect("timed out waiting for pings");
}

#[test]
fn http1_trailer_headers() {
let body = futures_util::stream::once(async move { Ok("hello".into()) });
let mut headers = HeaderMap::new();
headers.insert("chunky-trailer", "header data".parse().unwrap());

let server = serve();
server
.reply()
.header("transfer-encoding", "chunked")
.header("trailer", "chunky-trailer")
.body_stream_with_trailers(body, headers);
let mut req = connect(server.addr());
req.write_all(
b"\
GET / HTTP/1.1\r\n\
Host: example.domain\r\n\
Connection: keep-alive\r\n\
TE: trailers\r\n\
\r\n\
",
)
.expect("writing");

let chunky_trailer_chunk = b"\r\nchunky-trailer: header data\r\n\r\n";
let res = read_until(&mut req, |buf| buf.ends_with(chunky_trailer_chunk)).expect("reading");
let sres = s(&res);
dbg!(&sres);

let expected_head =
"HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n";
assert_eq!(&sres[..expected_head.len()], expected_head);

// skip the date header
let date_fragment = "GMT\r\n\r\n";
let pos = sres.find(date_fragment).expect("find GMT");
let body = &sres[pos + date_fragment.len()..];

let expected_body = "5\r\nhello\r\n0\r\nchunky-trailer: header data\r\n\r\n";
assert_eq!(body, expected_body);
}

// -------------------------------------------------
// the Server that is used to run all the tests with
// -------------------------------------------------
Expand Down Expand Up @@ -2700,6 +2742,19 @@ impl<'a> ReplyBuilder<'a> {
self.tx.lock().unwrap().send(Reply::Body(body)).unwrap();
}

fn body_stream_with_trailers<S>(self, stream: S, trailers: HeaderMap)
where
S: futures_util::Stream<Item = Result<Bytes, BoxError>> + Send + Sync + 'static,
{
use futures_util::TryStreamExt;
use hyper::body::Frame;
use support::trailers::StreamBodyWithTrailers;
let mut stream_body = StreamBodyWithTrailers::new(stream.map_ok(Frame::data));
stream_body.set_trailers(trailers);
let body = BodyExt::boxed(stream_body);
self.tx.lock().unwrap().send(Reply::Body(body)).unwrap();
}

#[allow(dead_code)]
fn error<E: Into<BoxError>>(self, err: E) {
self.tx
Expand Down
2 changes: 2 additions & 0 deletions tests/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub use std::net::SocketAddr;
mod tokiort;
pub use tokiort::{TokioExecutor, TokioIo, TokioTimer};

pub mod trailers;

#[allow(unused_macros)]
macro_rules! t {
(
Expand Down
Loading

0 comments on commit 60eaca9

Please sign in to comment.