Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(http1): Add support for writing Trailer Fields #3375

Merged
merged 5 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ where
// We assume a modern world where the remote speaks HTTP/1.1.
// If they tell us otherwise, we'll downgrade in `read_head`.
version: Version::HTTP_11,
allow_trailer_fields: false,
},
_marker: PhantomData,
}
Expand Down Expand Up @@ -264,6 +265,16 @@ where
self.state.reading = Reading::Body(Decoder::new(msg.decode));
}

if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) {
if te_value.eq_ignore_ascii_case("trailers") {
self.state.allow_trailer_fields = true;
} else {
self.state.allow_trailer_fields = false;
}
} else {
self.state.allow_trailer_fields = false;
}
hjr3 marked this conversation as resolved.
Show resolved Hide resolved

Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
}

Expand Down Expand Up @@ -640,6 +651,31 @@ where
self.state.writing = state;
}

pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) {
if T::is_server() && self.state.allow_trailer_fields == false {
debug!("trailers not allowed to be sent");
return;
}
debug_assert!(self.can_write_body() && self.can_buffer_body());

match self.state.writing {
Writing::Body(ref encoder) => {
if let Some(enc_buf) =
encoder.encode_trailers(trailers, self.state.title_case_headers)
{
self.io.buffer(enc_buf);

self.state.writing = if encoder.is_last() || encoder.is_close_delimited() {
Writing::Closed
} else {
Writing::KeepAlive
};
}
}
_ => unreachable!("write_trailers invalid state: {:?}", self.state.writing),
}
}

pub(crate) fn write_body_and_end(&mut self, chunk: B) {
debug_assert!(self.can_write_body() && self.can_buffer_body());
// empty chunks should be discarded at Dispatcher level
Expand Down Expand Up @@ -842,6 +878,8 @@ struct State {
upgrade: Option<crate::upgrade::Pending>,
/// Either HTTP/1.0 or 1.1 connection
version: Version,
/// Flag to track if trailer fields are allowed to be sent
allow_trailer_fields: bool,
}

#[derive(Debug)]
Expand Down
42 changes: 24 additions & 18 deletions src/proto/h1/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,27 +351,33 @@ where
*clear_body = true;
crate::Error::new_user_body(e)
})?;
let chunk = if let Ok(data) = frame.into_data() {
data
} else {
trace!("discarding non-data frame");
continue;
};
let eos = body.is_end_stream();
if eos {
*clear_body = true;
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
self.conn.end_body()?;

if frame.is_data() {
let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
let eos = body.is_end_stream();
if eos {
*clear_body = true;
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
self.conn.end_body()?;
} else {
self.conn.write_body_and_end(chunk);
}
} else {
self.conn.write_body_and_end(chunk);
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
continue;
}
self.conn.write_body(chunk);
}
} else if frame.is_trailers() {
*clear_body = true;
self.conn.write_trailers(
frame.into_trailers().unwrap_or_else(|_| unreachable!()),
);
} else {
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
continue;
}
self.conn.write_body(chunk);
trace!("discarding unknown frame");
continue;
}
} else {
*clear_body = true;
Expand Down
124 changes: 118 additions & 6 deletions src/proto/h1/encode.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
use std::collections::HashMap;
use std::fmt;
use std::io::IoSlice;

use bytes::buf::{Chain, Take};
use bytes::Buf;
use bytes::{Buf, Bytes};
use http::{
header::{
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TRAILER, TRANSFER_ENCODING,
},
HeaderMap, HeaderName, HeaderValue,
};

use super::io::WriteBuf;
use super::role::{write_headers, write_headers_title_case};

type StaticBuf = &'static [u8];

Expand All @@ -26,7 +35,7 @@ pub(crate) struct NotEof(u64);
#[derive(Debug, PartialEq, Clone)]
enum Kind {
/// An Encoder for when Transfer-Encoding includes `chunked`.
Chunked,
Chunked(Option<Vec<HeaderValue>>),
/// An Encoder for when Content-Length is set.
///
/// Enforces that the body is not longer than the Content-Length header.
Expand All @@ -45,6 +54,7 @@ enum BufKind<B> {
Limited(Take<B>),
Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
ChunkedEnd(StaticBuf),
Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
}

impl Encoder {
Expand All @@ -55,7 +65,7 @@ impl Encoder {
}
}
pub(crate) fn chunked() -> Encoder {
Encoder::new(Kind::Chunked)
Encoder::new(Kind::Chunked(None))
}

pub(crate) fn length(len: u64) -> Encoder {
Expand All @@ -67,6 +77,16 @@ impl Encoder {
Encoder::new(Kind::CloseDelimited)
}

pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
match self.kind {
Kind::Chunked(_) => Encoder {
kind: Kind::Chunked(Some(trailers)),
is_last: self.is_last,
},
_ => self,
}
}

pub(crate) fn is_eof(&self) -> bool {
matches!(self.kind, Kind::Length(0))
}
Expand All @@ -89,10 +109,17 @@ impl Encoder {
}
}

pub(crate) fn is_chunked(&self) -> bool {
match self.kind {
Kind::Chunked(_) => true,
_ => false,
}
}

pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
match self.kind {
Kind::Length(0) => Ok(None),
Kind::Chunked => Ok(Some(EncodedBuf {
Kind::Chunked(_) => Ok(Some(EncodedBuf {
kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
})),
#[cfg(feature = "server")]
Expand All @@ -109,7 +136,7 @@ impl Encoder {
debug_assert!(len > 0, "encode() called with empty buf");

let kind = match self.kind {
Kind::Chunked => {
Kind::Chunked(_) => {
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
Expand All @@ -136,6 +163,54 @@ impl Encoder {
EncodedBuf { kind }
}

pub(crate) fn encode_trailers<B>(
hjr3 marked this conversation as resolved.
Show resolved Hide resolved
&self,
mut trailers: HeaderMap,
title_case_headers: bool,
) -> Option<EncodedBuf<B>> {
match &self.kind {
Kind::Chunked(allowed_trailer_fields) => {
let allowed_trailer_fields_map = match allowed_trailer_fields {
Some(ref allowed_trailer_fields) => {
allowed_trailer_field_map(&allowed_trailer_fields)
}
None => return None,
};
hjr3 marked this conversation as resolved.
Show resolved Hide resolved

let mut cur_name = None;
let mut allowed_trailers = HeaderMap::new();

for (opt_name, value) in trailers.drain() {
hjr3 marked this conversation as resolved.
Show resolved Hide resolved
if let Some(n) = opt_name {
cur_name = Some(n);
}
let name = cur_name.as_ref().expect("current header name");

if allowed_trailer_fields_map.contains_key(name.as_str())
&& !invalid_trailer_field(name)
{
allowed_trailers.insert(name, value);
}
}

let mut buf = Vec::new();
if title_case_headers {
write_headers_title_case(&allowed_trailers, &mut buf);
} else {
write_headers(&allowed_trailers, &mut buf);
}

Some(EncodedBuf {
kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
})
}
_ => {
debug!("attempted to encode trailers for non-chunked response");
None
}
}
}

pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
where
B: Buf,
Expand All @@ -144,7 +219,7 @@ impl Encoder {
debug_assert!(len > 0, "encode() called with empty buf");

match self.kind {
Kind::Chunked => {
Kind::Chunked(_) => {
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
Expand Down Expand Up @@ -181,6 +256,39 @@ impl Encoder {
}
}

fn invalid_trailer_field(name: &HeaderName) -> bool {
hjr3 marked this conversation as resolved.
Show resolved Hide resolved
match name {
&AUTHORIZATION => true,
&CACHE_CONTROL => true,
&CONTENT_ENCODING => true,
&CONTENT_LENGTH => true,
&CONTENT_RANGE => true,
&CONTENT_TYPE => true,
&HOST => true,
&MAX_FORWARDS => true,
&SET_COOKIE => true,
&TRAILER => true,
&TRANSFER_ENCODING => true,
_ => false,
hjr3 marked this conversation as resolved.
Show resolved Hide resolved
}
}

fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
let mut trailer_map = HashMap::new();

for header_value in allowed_trailer_fields {
if let Ok(header_str) = header_value.to_str() {
let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();

for item in items {
trailer_map.entry(item.to_string()).or_insert(());
}
}
}

trailer_map
}

impl<B> Buf for EncodedBuf<B>
where
B: Buf,
Expand All @@ -192,6 +300,7 @@ where
BufKind::Limited(ref b) => b.remaining(),
BufKind::Chunked(ref b) => b.remaining(),
BufKind::ChunkedEnd(ref b) => b.remaining(),
BufKind::Trailers(ref b) => b.remaining(),
}
}

Expand All @@ -202,6 +311,7 @@ where
BufKind::Limited(ref b) => b.chunk(),
BufKind::Chunked(ref b) => b.chunk(),
BufKind::ChunkedEnd(ref b) => b.chunk(),
BufKind::Trailers(ref b) => b.chunk(),
}
}

Expand All @@ -212,6 +322,7 @@ where
BufKind::Limited(ref mut b) => b.advance(cnt),
BufKind::Chunked(ref mut b) => b.advance(cnt),
BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
BufKind::Trailers(ref mut b) => b.advance(cnt),
}
}

Expand All @@ -222,6 +333,7 @@ where
BufKind::Limited(ref b) => b.chunks_vectored(dst),
BufKind::Chunked(ref b) => b.chunks_vectored(dst),
BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
BufKind::Trailers(ref b) => b.chunks_vectored(dst),
}
}
}
Expand Down
Loading