Skip to content

Commit

Permalink
zstd compression
Browse files Browse the repository at this point in the history
  • Loading branch information
paolobarbolini committed Jun 16, 2024
1 parent fb27536 commit d89d4fc
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 3 deletions.
2 changes: 2 additions & 0 deletions async-nats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ ring = { version = "0.17", optional = true }
rand = "0.8"
webpki = { package = "rustls-webpki", version = "0.102" }
portable-atomic = "1"
async-compression = { version = "0.4.11", features = ["tokio"], optional = true }

[dev-dependencies]
ring = "0.17"
Expand All @@ -65,6 +66,7 @@ service = []
aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs"]
ring = ["dep:ring", "tokio-rustls/ring"]
fips = ["aws-lc-rs", "tokio-rustls/fips"]
zstd = ["dep:async-compression", "async-compression/zstd"]
# All experimental features are part of this feature flag.
experimental = ["service"]
# Used for enabling/disabling tests that by design take a lot of time to complete.
Expand Down
1 change: 1 addition & 0 deletions async-nats/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,7 @@ mod write_op {
auth_token: None,
headers: false,
no_responders: false,
m4ss_zstd: false,
})]
.iter(),
)
Expand Down
171 changes: 168 additions & 3 deletions async-nats/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ impl<const CLIENT_NODE: bool> Handler<CLIENT_NODE> {
echo: !self.options.no_echo,
headers: true,
no_responders: true,
m4ss_zstd: server_info.m4ss_zstd,
};

if let Some(nkey) = self.options.auth.nkey.as_ref() {
Expand Down Expand Up @@ -399,10 +400,174 @@ impl<const CLIENT_NODE: bool> Handler<CLIENT_NODE> {
connect_info.nkey = auth.nkey;
}

#[cfg(feature = "zstd")]
let m4ss_zstd = connect_info.m4ss_zstd;

connection
.easy_write_and_flush([ClientOp::Connect(connect_info)].iter())
.await
.map_err(E::WriteStream)?;

#[cfg(feature = "zstd")]
if m4ss_zstd {
use std::pin::Pin;
use std::task::{Context, Poll, Waker};

use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};

#[derive(Debug)]
struct Maybe<T> {
item: Option<T>,
waker: Option<Waker>,
}

impl<T> Maybe<T> {
fn new(item: Option<T>) -> Self {
Self { item, waker: None }
}

fn take_item(&mut self) -> Option<T> {
self.item.take()
}

fn set_item(&mut self, item: T) {
self.item = Some(item);
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}

impl<T: AsyncRead + Unpin> AsyncRead for Maybe<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut self.item {
Some(item) => Pin::new(item).poll_read(cx, buf),
None => {
self.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}

impl<T: AsyncWrite + Unpin> AsyncWrite for Maybe<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut self.item {
Some(item) => Pin::new(item).poll_write(cx, buf),
None => {
self.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
match &mut self.item {
Some(item) => Pin::new(item).poll_flush(cx),
None => {
self.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}

let stream = connection.stream;

let decompressor = async_compression::tokio::bufread::ZstdDecoder::new(
BufReader::new(Maybe::new(Some(stream))),
);
let compressor =
async_compression::tokio::write::ZstdEncoder::new(Maybe::new(None));

struct CompressedStream<S> {
decompressor:
async_compression::tokio::bufread::ZstdDecoder<BufReader<Maybe<S>>>,
compressor: async_compression::tokio::write::ZstdEncoder<Maybe<S>>,
}

impl<S> AsyncRead for CompressedStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(stream) = self.compressor.get_mut().take_item() {
self.decompressor.get_mut().get_mut().set_item(stream);
}

Pin::new(&mut self.decompressor).poll_read(cx, buf)
}
}

impl<S> AsyncWrite for CompressedStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if let Some(stream) =
self.decompressor.get_mut().get_mut().take_item()
{
self.compressor.get_mut().set_item(stream);
}

Pin::new(&mut self.compressor).poll_write(cx, buf)
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
if let Some(stream) =
self.decompressor.get_mut().get_mut().take_item()
{
self.compressor.get_mut().set_item(stream);
}

Pin::new(&mut self.compressor).poll_flush(cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}

connection.stream = Box::new(CompressedStream {
decompressor,
compressor,
});
}

connection
.easy_write_and_flush(
[ClientOp::Connect(connect_info), ClientOp::Ping].iter(),
)
.easy_write_and_flush([ClientOp::Ping].iter())
.await
.map_err(E::WriteStream)?;

Expand Down
9 changes: 9 additions & 0 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ pub struct ServerInfo {
/// Whether server goes into lame duck mode.
#[serde(default, rename = "ldm")]
pub lame_duck_mode: bool,
#[serde(default)]
pub m4ss_zstd: bool,
}

#[derive(Clone, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -1478,6 +1480,13 @@ pub struct ConnectInfo {

/// Whether the client supports no_responders.
pub no_responders: bool,

#[serde(skip_serializing_if = "is_default")]
pub m4ss_zstd: bool,
}

fn is_default(m4ss_zstd: &bool) -> bool {
!*m4ss_zstd
}

/// Protocol version used by the client.
Expand Down

0 comments on commit d89d4fc

Please sign in to comment.