diff --git a/async-nats/Cargo.toml b/async-nats/Cargo.toml index 6e5da265c..1a7351f56 100644 --- a/async-nats/Cargo.toml +++ b/async-nats/Cargo.toml @@ -40,6 +40,7 @@ ring = { version = "0.17", optional = true } rand = "0.8" webpki = { package = "rustls-webpki", version = "0.102" } portable-atomic = "1" +async-compression = { version = "0.4.11", features = ["tokio"], optional = true } [dev-dependencies] ring = "0.17" @@ -65,6 +66,7 @@ service = [] aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs"] ring = ["dep:ring", "tokio-rustls/ring"] fips = ["aws-lc-rs", "tokio-rustls/fips"] +zstd = ["dep:async-compression", "async-compression/zstd"] # All experimental features are part of this feature flag. experimental = ["service"] # Used for enabling/disabling tests that by design take a lot of time to complete. diff --git a/async-nats/src/connection.rs b/async-nats/src/connection.rs index 66fa28bb3..6154640bd 100644 --- a/async-nats/src/connection.rs +++ b/async-nats/src/connection.rs @@ -1329,6 +1329,7 @@ mod write_op { auth_token: None, headers: false, no_responders: false, + m4ss_zstd: false, })] .iter(), ) diff --git a/async-nats/src/connector.rs b/async-nats/src/connector.rs index 9d105d109..63f009eea 100644 --- a/async-nats/src/connector.rs +++ b/async-nats/src/connector.rs @@ -357,6 +357,7 @@ impl Handler { echo: !self.options.no_echo, headers: true, no_responders: true, + m4ss_zstd: server_info.m4ss_zstd, }; if let Some(nkey) = self.options.auth.nkey.as_ref() { @@ -399,10 +400,174 @@ impl Handler { connect_info.nkey = auth.nkey; } + #[cfg(feature = "zstd")] + let m4ss_zstd = connect_info.m4ss_zstd; + + connection + .easy_write_and_flush([ClientOp::Connect(connect_info)].iter()) + .await + .map_err(E::WriteStream)?; + + #[cfg(feature = "zstd")] + if m4ss_zstd { + use std::pin::Pin; + use std::task::{Context, Poll, Waker}; + + use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf}; + + #[derive(Debug)] + struct Maybe { + item: Option, + waker: Option, + } + + impl Maybe { + fn new(item: Option) -> Self { + Self { item, waker: None } + } + + fn take_item(&mut self) -> Option { + self.item.take() + } + + fn set_item(&mut self, item: T) { + self.item = Some(item); + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } + } + + impl AsyncRead for Maybe { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut self.item { + Some(item) => Pin::new(item).poll_read(cx, buf), + None => { + self.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } + } + + impl AsyncWrite for Maybe { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut self.item { + Some(item) => Pin::new(item).poll_write(cx, buf), + None => { + self.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match &mut self.item { + Some(item) => Pin::new(item).poll_flush(cx), + None => { + self.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + let stream = connection.stream; + + let decompressor = async_compression::tokio::bufread::ZstdDecoder::new( + BufReader::new(Maybe::new(Some(stream))), + ); + let compressor = + async_compression::tokio::write::ZstdEncoder::new(Maybe::new(None)); + + struct CompressedStream { + decompressor: + async_compression::tokio::bufread::ZstdDecoder>>, + compressor: async_compression::tokio::write::ZstdEncoder>, + } + + impl AsyncRead for CompressedStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(stream) = self.compressor.get_mut().take_item() { + self.decompressor.get_mut().get_mut().set_item(stream); + } + + Pin::new(&mut self.decompressor).poll_read(cx, buf) + } + } + + impl AsyncWrite for CompressedStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(stream) = + self.decompressor.get_mut().get_mut().take_item() + { + self.compressor.get_mut().set_item(stream); + } + + Pin::new(&mut self.compressor).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(stream) = + self.decompressor.get_mut().get_mut().take_item() + { + self.compressor.get_mut().set_item(stream); + } + + Pin::new(&mut self.compressor).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + connection.stream = Box::new(CompressedStream { + decompressor, + compressor, + }); + } + connection - .easy_write_and_flush( - [ClientOp::Connect(connect_info), ClientOp::Ping].iter(), - ) + .easy_write_and_flush([ClientOp::Ping].iter()) .await .map_err(E::WriteStream)?; diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 3aa1ea1d1..f76beff5b 100755 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -323,6 +323,8 @@ pub struct ServerInfo { /// Whether server goes into lame duck mode. #[serde(default, rename = "ldm")] pub lame_duck_mode: bool, + #[serde(default)] + pub m4ss_zstd: bool, } #[derive(Clone, Debug, Eq, PartialEq)] @@ -1478,6 +1480,13 @@ pub struct ConnectInfo { /// Whether the client supports no_responders. pub no_responders: bool, + + #[serde(skip_serializing_if = "is_default")] + pub m4ss_zstd: bool, +} + +fn is_default(m4ss_zstd: &bool) -> bool { + !*m4ss_zstd } /// Protocol version used by the client.