Skip to content

Commit

Permalink
feat(http1) Add support for writing Trailer Fields
Browse files Browse the repository at this point in the history
Closes #2719
  • Loading branch information
hjr3 committed Nov 4, 2023
1 parent 04f0981 commit 03f9328
Show file tree
Hide file tree
Showing 8 changed files with 473 additions and 30 deletions.
38 changes: 38 additions & 0 deletions src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ where
// We assume a modern world where the remote speaks HTTP/1.1.
// If they tell us otherwise, we'll downgrade in `read_head`.
version: Version::HTTP_11,
allow_trailer_fields: false,
},
_marker: PhantomData,
}
Expand Down Expand Up @@ -264,6 +265,16 @@ where
self.state.reading = Reading::Body(Decoder::new(msg.decode));
}

if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) {
if te_value.eq_ignore_ascii_case("trailers") {
self.state.allow_trailer_fields = true;
} else {
self.state.allow_trailer_fields = false;
}
} else {
self.state.allow_trailer_fields = false;
}

Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
}

Expand Down Expand Up @@ -640,6 +651,31 @@ where
self.state.writing = state;
}

pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) {
if T::is_server() && self.state.allow_trailer_fields == false {
debug!("trailers not allowed to be sent");
return;
}
debug_assert!(self.can_write_body() && self.can_buffer_body());

match self.state.writing {
Writing::Body(ref encoder) => {
if let Some(enc_buf) =
encoder.encode_trailers(trailers, self.state.title_case_headers)
{
self.io.buffer(enc_buf);

self.state.writing = if encoder.is_last() || encoder.is_close_delimited() {
Writing::Closed
} else {
Writing::KeepAlive
};
}
}
_ => unreachable!("write_trailers invalid state: {:?}", self.state.writing),
}
}

pub(crate) fn write_body_and_end(&mut self, chunk: B) {
debug_assert!(self.can_write_body() && self.can_buffer_body());
// empty chunks should be discarded at Dispatcher level
Expand Down Expand Up @@ -842,6 +878,8 @@ struct State {
upgrade: Option<crate::upgrade::Pending>,
/// Either HTTP/1.0 or 1.1 connection
version: Version,
/// Flag to track if trailer fields are allowed to be sent
allow_trailer_fields: bool,
}

#[derive(Debug)]
Expand Down
42 changes: 24 additions & 18 deletions src/proto/h1/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,27 +351,33 @@ where
*clear_body = true;
crate::Error::new_user_body(e)
})?;
let chunk = if let Ok(data) = frame.into_data() {
data
} else {
trace!("discarding non-data frame");
continue;
};
let eos = body.is_end_stream();
if eos {
*clear_body = true;
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
self.conn.end_body()?;

if frame.is_data() {
let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
let eos = body.is_end_stream();
if eos {
*clear_body = true;
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
self.conn.end_body()?;
} else {
self.conn.write_body_and_end(chunk);
}
} else {
self.conn.write_body_and_end(chunk);
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
continue;
}
self.conn.write_body(chunk);
}
} else if frame.is_trailers() {
*clear_body = true;
self.conn.write_trailers(
frame.into_trailers().unwrap_or_else(|_| unreachable!()),
);
} else {
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
continue;
}
self.conn.write_body(chunk);
trace!("discarding unknown frame");
continue;
}
} else {
*clear_body = true;
Expand Down
124 changes: 118 additions & 6 deletions src/proto/h1/encode.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
use std::collections::HashMap;
use std::fmt;
use std::io::IoSlice;

use bytes::buf::{Chain, Take};
use bytes::Buf;
use bytes::{Buf, Bytes};
use http::{
header::{
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TRAILER, TRANSFER_ENCODING,
},
HeaderMap, HeaderName, HeaderValue,
};

use super::io::WriteBuf;
use super::role::{write_headers, write_headers_title_case};

type StaticBuf = &'static [u8];

Expand All @@ -26,7 +35,7 @@ pub(crate) struct NotEof(u64);
#[derive(Debug, PartialEq, Clone)]
enum Kind {
/// An Encoder for when Transfer-Encoding includes `chunked`.
Chunked,
Chunked(Option<Vec<HeaderValue>>),
/// An Encoder for when Content-Length is set.
///
/// Enforces that the body is not longer than the Content-Length header.
Expand All @@ -45,6 +54,7 @@ enum BufKind<B> {
Limited(Take<B>),
Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
ChunkedEnd(StaticBuf),
Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
}

impl Encoder {
Expand All @@ -55,7 +65,7 @@ impl Encoder {
}
}
pub(crate) fn chunked() -> Encoder {
Encoder::new(Kind::Chunked)
Encoder::new(Kind::Chunked(None))
}

pub(crate) fn length(len: u64) -> Encoder {
Expand All @@ -67,6 +77,16 @@ impl Encoder {
Encoder::new(Kind::CloseDelimited)
}

pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
match self.kind {
Kind::Chunked(_) => Encoder {
kind: Kind::Chunked(Some(trailers)),
is_last: self.is_last,
},
_ => self,
}
}

pub(crate) fn is_eof(&self) -> bool {
matches!(self.kind, Kind::Length(0))
}
Expand All @@ -89,10 +109,17 @@ impl Encoder {
}
}

pub(crate) fn is_chunked(&self) -> bool {
match self.kind {
Kind::Chunked(_) => true,
_ => false,
}
}

pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
match self.kind {
Kind::Length(0) => Ok(None),
Kind::Chunked => Ok(Some(EncodedBuf {
Kind::Chunked(_) => Ok(Some(EncodedBuf {
kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
})),
#[cfg(feature = "server")]
Expand All @@ -109,7 +136,7 @@ impl Encoder {
debug_assert!(len > 0, "encode() called with empty buf");

let kind = match self.kind {
Kind::Chunked => {
Kind::Chunked(_) => {
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
Expand All @@ -136,6 +163,54 @@ impl Encoder {
EncodedBuf { kind }
}

pub(crate) fn encode_trailers<B>(
&self,
mut trailers: HeaderMap,
title_case_headers: bool,
) -> Option<EncodedBuf<B>> {
match &self.kind {
Kind::Chunked(allowed_trailer_fields) => {
let allowed_trailer_fields_map = match allowed_trailer_fields {
Some(ref allowed_trailer_fields) => {
allowed_trailer_field_map(&allowed_trailer_fields)
}
None => return None,
};

let mut cur_name = None;
let mut allowed_trailers = HeaderMap::new();

for (opt_name, value) in trailers.drain() {
if let Some(n) = opt_name {
cur_name = Some(n);
}
let name = cur_name.as_ref().expect("current header name");

if allowed_trailer_fields_map.contains_key(name.as_str())
&& !invalid_trailer_field(name)
{
allowed_trailers.insert(name, value);
}
}

let mut buf = Vec::new();
if title_case_headers {
write_headers_title_case(&allowed_trailers, &mut buf);
} else {
write_headers(&allowed_trailers, &mut buf);
}

Some(EncodedBuf {
kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
})
}
_ => {
debug!("attempted to encode trailers for non-chunked response");
None
}
}
}

pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
where
B: Buf,
Expand All @@ -144,7 +219,7 @@ impl Encoder {
debug_assert!(len > 0, "encode() called with empty buf");

match self.kind {
Kind::Chunked => {
Kind::Chunked(_) => {
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
Expand Down Expand Up @@ -181,6 +256,39 @@ impl Encoder {
}
}

fn invalid_trailer_field(name: &HeaderName) -> bool {
match name {
&AUTHORIZATION => true,
&CACHE_CONTROL => true,
&CONTENT_ENCODING => true,
&CONTENT_LENGTH => true,
&CONTENT_RANGE => true,
&CONTENT_TYPE => true,
&HOST => true,
&MAX_FORWARDS => true,
&SET_COOKIE => true,
&TRAILER => true,
&TRANSFER_ENCODING => true,
_ => false,
}
}

fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
let mut trailer_map = HashMap::new();

for header_value in allowed_trailer_fields {
if let Ok(header_str) = header_value.to_str() {
let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();

for item in items {
trailer_map.entry(item.to_string()).or_insert(());
}
}
}

trailer_map
}

impl<B> Buf for EncodedBuf<B>
where
B: Buf,
Expand All @@ -192,6 +300,7 @@ where
BufKind::Limited(ref b) => b.remaining(),
BufKind::Chunked(ref b) => b.remaining(),
BufKind::ChunkedEnd(ref b) => b.remaining(),
BufKind::Trailers(ref b) => b.remaining(),
}
}

Expand All @@ -202,6 +311,7 @@ where
BufKind::Limited(ref b) => b.chunk(),
BufKind::Chunked(ref b) => b.chunk(),
BufKind::ChunkedEnd(ref b) => b.chunk(),
BufKind::Trailers(ref b) => b.chunk(),
}
}

Expand All @@ -212,6 +322,7 @@ where
BufKind::Limited(ref mut b) => b.advance(cnt),
BufKind::Chunked(ref mut b) => b.advance(cnt),
BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
BufKind::Trailers(ref mut b) => b.advance(cnt),
}
}

Expand All @@ -222,6 +333,7 @@ where
BufKind::Limited(ref b) => b.chunks_vectored(dst),
BufKind::Chunked(ref b) => b.chunks_vectored(dst),
BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
BufKind::Trailers(ref b) => b.chunks_vectored(dst),
}
}
}
Expand Down
Loading

0 comments on commit 03f9328

Please sign in to comment.