Skip to content

Commit

Permalink
feat: add new API to close the client
Browse files Browse the repository at this point in the history
  • Loading branch information
dodomorandi committed Apr 23, 2024
1 parent 8c92e8e commit 4ea7a9f
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 25 deletions.
75 changes: 65 additions & 10 deletions async-nats/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tracing::trace;

static VERSION_RE: Lazy<Regex> =
Expand Down Expand Up @@ -72,18 +73,46 @@ pub struct Client {
inbox_prefix: Arc<str>,
request_timeout: Option<Duration>,
max_payload: Arc<AtomicUsize>,
inner: Arc<std::sync::Mutex<Option<Inner>>>,
}

impl Client {
pub(crate) fn new(
info: tokio::sync::watch::Receiver<ServerInfo>,
state: tokio::sync::watch::Receiver<State>,
sender: mpsc::Sender<Command>,
capacity: usize,
inbox_prefix: String,
request_timeout: Option<Duration>,
max_payload: Arc<AtomicUsize>,
) -> Client {
#[derive(Debug)]
struct Inner {
events_handler_task: JoinHandle<()>,
connection_handler_task: JoinHandle<()>,
}

pub(super) struct Builder {
pub(super) info: tokio::sync::watch::Receiver<ServerInfo>,
pub(super) state: tokio::sync::watch::Receiver<State>,
pub(super) sender: mpsc::Sender<Command>,
pub(super) capacity: usize,
pub(super) inbox_prefix: String,
pub(super) request_timeout: Option<Duration>,
pub(super) max_payload: Arc<AtomicUsize>,
pub(super) events_handler_task: JoinHandle<()>,
pub(super) connection_handler_task: JoinHandle<()>,
}

impl Builder {
pub(super) fn build(self) -> Client {
let Self {
info,
state,
sender,
capacity,
inbox_prefix,
request_timeout,
max_payload,
events_handler_task,
connection_handler_task,
} = self;

let inner = Arc::new(std::sync::Mutex::new(Some(Inner {
events_handler_task,
connection_handler_task,
})));

Client {
info,
state,
Expand All @@ -93,9 +122,12 @@ impl Client {
inbox_prefix: inbox_prefix.into(),
request_timeout,
max_payload,
inner,
}
}
}

impl Client {
/// Returns last received info from the server.
///
/// # Examples
Expand Down Expand Up @@ -607,6 +639,29 @@ impl Client {
.await
.map_err(Into::into)
}

/// Stops the NATS client
///
/// # Panics
///
/// Panics if the client has been already closed or if one of the tasks involved in events and
/// connection handling panic.
pub async fn stop(self) {
self.sender
.send(Command::Close)
.await
.expect("client already closed");

let inner = self
.inner
.lock()
.unwrap()
.take()
.expect("client already closed");

inner.connection_handler_task.await.unwrap();
inner.events_handler_task.await.unwrap();
}
}

/// Used for building customized requests.
Expand Down
43 changes: 28 additions & 15 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,13 @@ pub(crate) enum Command {
observer: oneshot::Sender<()>,
},
Reconnect,
Close,
}

#[derive(Debug)]
enum ShouldClose {
Yes,
No,
}

/// `ClientOp` represents all actions of `Client`.
Expand Down Expand Up @@ -542,7 +549,9 @@ impl ConnectionHandler {
made_progress = true;

for cmd in recv_buf.drain(..) {
handler.handle_command(cmd);
if matches!(handler.handle_command(cmd), ShouldClose::Yes) {
return Poll::Ready(ExitReason::Closed);
}
}
}
// TODO: replace `_` with `0` after bumping MSRV to 1.75
Expand Down Expand Up @@ -731,7 +740,7 @@ impl ConnectionHandler {
}
}

fn handle_command(&mut self, command: Command) {
fn handle_command(&mut self, command: Command) -> ShouldClose {
self.ping_interval.reset();

match command {
Expand Down Expand Up @@ -836,7 +845,10 @@ impl ConnectionHandler {
Command::Reconnect => {
self.should_reconnect = true;
}

Command::Close => return ShouldClose::Yes,
}
ShouldClose::No
}

async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
Expand Down Expand Up @@ -938,17 +950,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
let (sender, mut receiver) = mpsc::channel(options.sender_capacity);

let client = Client::new(
info_watcher,
state_rx,
sender,
options.subscription_capacity,
options.inbox_prefix,
options.request_timeout,
max_payload,
);

task::spawn(async move {
let events_handler_task = task::spawn(async move {
while let Some(event) = events_rx.recv().await {
tracing::info!("event: {}", event);
if let Some(event_callback) = &options.event_callback {
Expand All @@ -957,7 +959,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
}
});

task::spawn(async move {
let connection_handler_task = task::spawn(async move {
if connection.is_none() && options.retry_on_initial_connect {
let (info, connection_ok) = match connector.connect().await {
Ok((info, connection)) => (info, connection),
Expand All @@ -975,7 +977,18 @@ pub async fn connect_with_options<A: ToServerAddrs>(
connection_handler.process(&mut receiver).await
});

Ok(client)
Ok(client::Builder {
info: info_watcher,
state: state_rx,
sender,
capacity: options.subscription_capacity,
inbox_prefix: options.inbox_prefix,
request_timeout: options.request_timeout,
max_payload,
events_handler_task,
connection_handler_task,
}
.build())
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down

0 comments on commit 4ea7a9f

Please sign in to comment.