Skip to content

Commit

Permalink
feat: expose connector functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dodomorandi committed Apr 26, 2024
1 parent 20bdd21 commit d5c7a5e
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 61 deletions.
2 changes: 1 addition & 1 deletion async-nats/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl Display for State {
}

/// A framed connection
pub(crate) struct Connection {
pub struct Connection {
pub(crate) stream: Box<dyn AsyncReadWrite>,
read_buf: BytesMut,
write_buf: VecDeque<Bytes>,
Expand Down
155 changes: 148 additions & 7 deletions async-nats/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ use crate::auth::Auth;
use crate::connection::Connection;
use crate::connection::ReadOpError;
use crate::connection::State;
use crate::handle_events;
use crate::options::CallbackArg1;
use crate::tls;
use crate::AuthError;
use crate::ClientOp;
use crate::ConnectInfo;
use crate::ConnectOptions;
use crate::Event;
use crate::MaybeArc;
use crate::Protocol;
Expand All @@ -33,6 +35,7 @@ use crate::LANG;
use crate::VERSION;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::engine::Engine;
use futures::FutureExt;
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::cmp;
Expand All @@ -41,14 +44,16 @@ use std::fmt;
use std::fmt::Display;
use std::io;
use std::path::PathBuf;
use std::pin::pin;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::time::sleep;
use tokio_rustls::rustls;

pub(crate) struct ConnectorOptions {
pub struct ConnectorOptions {
pub(crate) tls_required: bool,
pub(crate) certificates: Vec<PathBuf>,
pub(crate) client_cert: Option<PathBuf>,
Expand All @@ -68,7 +73,7 @@ pub(crate) struct ConnectorOptions {
}

/// Maintains a list of servers and establishes connections.
pub(crate) struct Connector {
pub(crate) struct Handler<const CLIENT_NODE: bool = true> {
/// A map of servers and number of connect attempts.
servers: Vec<(ServerAddr, usize)>,
options: ConnectorOptions,
Expand All @@ -88,17 +93,148 @@ pub(crate) fn reconnect_delay_callback_default(attempts: usize) -> Duration {
}
}

impl Connector {
pub(crate) fn new<A: ToServerAddrs>(
pub struct Connector<const CLIENT_NODE: bool = true> {
pub(crate) handler: Handler<CLIENT_NODE>,
pub(crate) events_rx: tokio::sync::mpsc::Receiver<Event>,
pub(crate) state_rx: tokio::sync::watch::Receiver<State>,
pub(crate) subscription_capacity: usize,
pub(crate) event_callback: Option<CallbackArg1<Event, ()>>,
pub(crate) inbox_prefix: String,
pub(crate) request_timeout: Option<Duration>,
pub(crate) retry_on_initial_connect: bool,
pub(crate) sender_capacity: usize,
pub(crate) ping_interval: Duration,
}

pub fn create<A: ToServerAddrs>(addrs: A, options: ConnectOptions) -> Result<Connector, io::Error> {
create_inner::<A, true>(addrs, options)
}

pub fn create_leaf_connector<A: ToServerAddrs>(
addrs: A,
options: ConnectOptions,
) -> Result<Connector<false>, io::Error> {
create_inner::<A, false>(addrs, options)
}

fn create_inner<A: ToServerAddrs, const CLIENT_NODE: bool>(
addrs: A,
options: ConnectOptions,
) -> Result<Connector<CLIENT_NODE>, io::Error> {
let ConnectOptions {
name,
no_echo,
max_reconnects,
connection_timeout,
auth,
tls_required,
tls_first,
certificates,
client_cert,
client_key,
tls_client_config,
ping_interval,
subscription_capacity,
sender_capacity,
event_callback,
inbox_prefix,
request_timeout,
retry_on_initial_connect,
ignore_discovered_servers,
retain_servers_order,
read_buffer_capacity,
reconnect_delay_callback,
auth_callback,
} = options;

let options = ConnectorOptions {
tls_required,
certificates,
client_cert,
client_key,
tls_client_config,
tls_first,
auth,
no_echo,
connection_timeout,
name,
ignore_discovered_servers,
retain_servers_order,
read_buffer_capacity,
reconnect_delay_callback,
auth_callback,
max_reconnects,
};

let (events_tx, events_rx) = mpsc::channel(128);
let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
// We're setting it to the default server payload size.
let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));

let handler = Handler::<CLIENT_NODE>::new(addrs, options, events_tx, state_tx, max_payload)?;
Ok(Connector {
handler,
events_rx,
state_rx,
subscription_capacity,
inbox_prefix,
request_timeout,
event_callback,
retry_on_initial_connect,
sender_capacity,
ping_interval,
})
}

impl<const CLIENT_NODE: bool> Connector<CLIENT_NODE> {
pub async fn connect(&mut self) -> Result<(ServerInfo, Connection), MaybeArc<Error>> {
let mut handle_events_fut =
pin!(handle_events(&mut self.events_rx, self.event_callback.as_ref()).fuse());
let mut connect_fut = pin!(self.handler.connect());

loop {
tokio::select! {
result = connect_fut.as_mut() => {
return result;
},

() = handle_events_fut.as_mut() => {
tracing::warn!("events handler finished unexpectedly");
},
}
}
}

pub async fn try_connect(&mut self) -> Result<(ServerInfo, Connection), Error> {
let mut handle_events_fut =
pin!(handle_events(&mut self.events_rx, self.event_callback.as_ref()).fuse());
let mut try_connect_fut = pin!(self.handler.try_connect());

loop {
tokio::select! {
result = try_connect_fut.as_mut() => {
return result;
},

() = handle_events_fut.as_mut() => {
tracing::warn!("events handler finished unexpectedly");
},
}
}
}
}

impl<const CLIENT_NODE: bool> Handler<CLIENT_NODE> {
fn new<A: ToServerAddrs>(
addrs: A,
options: ConnectorOptions,
events_tx: tokio::sync::mpsc::Sender<Event>,
state_tx: tokio::sync::watch::Sender<State>,
max_payload: Arc<AtomicUsize>,
) -> Result<Connector, io::Error> {
) -> Result<Self, io::Error> {
let servers = addrs.to_server_addrs()?.map(|addr| (addr, 0)).collect();

Ok(Connector {
Ok(Handler {
attempts: 0,
servers,
options,
Expand Down Expand Up @@ -194,12 +330,17 @@ impl Connector {
}

let tls_required = self.options.tls_required || server_addr.tls_required();
let lang = if CLIENT_NODE {
LANG.to_string()
} else {
String::new()
};
let mut connect_info = ConnectInfo {
tls_required,
name: self.options.name.clone(),
pedantic: false,
verbose: false,
lang: LANG.to_string(),
lang,
version: VERSION.to_string(),
protocol: Protocol::Dynamic,
user: self.options.auth.username.clone(),
Expand Down
93 changes: 40 additions & 53 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@
#![deny(rustdoc::invalid_rust_codeblocks)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]

use connector::Connector;
use options::CallbackArg1;
use thiserror::Error;

use futures::stream::Stream;
Expand All @@ -210,7 +212,6 @@ use std::option;
use std::pin::Pin;
use std::slice;
use std::str::{self, FromStr};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::ErrorKind;
Expand All @@ -237,7 +238,6 @@ const MULTIPLEXER_SID: u64 = 0;
pub use tokio_rustls::rustls;

use connection::{Connection, State};
use connector::{Connector, ConnectorOptions};
pub use header::{HeaderMap, HeaderName, HeaderValue};
pub use subject::Subject;

Expand Down Expand Up @@ -436,7 +436,7 @@ struct Multiplexer {
/// A connection handler which facilitates communication from channels to a single shared connection.
pub(crate) struct ConnectionHandler {
connection: Connection,
connector: Connector,
connector: connector::Handler,
subscriptions: HashMap<u64, Subscription>,
multiplexer: Option<Multiplexer>,
pending_pings: usize,
Expand All @@ -449,7 +449,7 @@ pub(crate) struct ConnectionHandler {
impl ConnectionHandler {
pub(crate) fn new(
connection: Connection,
connector: Connector,
connector: connector::Handler,
info_sender: tokio::sync::watch::Sender<ServerInfo>,
ping_period: Duration,
) -> ConnectionHandler {
Expand Down Expand Up @@ -956,63 +956,38 @@ pub async fn connect_with_options<A: ToServerAddrs>(
addrs: A,
options: ConnectOptions,
) -> Result<Client, ConnectError> {
let ping_period = options.ping_interval;

let (events_tx, mut events_rx) = mpsc::channel(128);
let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
// We're setting it to the default server payload size.
let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));

let mut connector = Connector::new(
addrs,
ConnectorOptions {
tls_required: options.tls_required,
certificates: options.certificates,
client_key: options.client_key,
client_cert: options.client_cert,
tls_client_config: options.tls_client_config,
tls_first: options.tls_first,
auth: options.auth,
no_echo: options.no_echo,
connection_timeout: options.connection_timeout,
name: options.name,
ignore_discovered_servers: options.ignore_discovered_servers,
retain_servers_order: options.retain_servers_order,
read_buffer_capacity: options.read_buffer_capacity,
reconnect_delay_callback: options.reconnect_delay_callback,
auth_callback: options.auth_callback,
max_reconnects: options.max_reconnects,
},
events_tx,
state_tx,
max_payload.clone(),
)
.map_err(ConnectError::InvalidServerAddress)?;
let Connector {
mut handler,
mut events_rx,
state_rx,
subscription_capacity,
inbox_prefix,
request_timeout,
event_callback,
retry_on_initial_connect,
sender_capacity,
ping_interval,
} = connector::create(addrs, options).map_err(ConnectError::InvalidServerAddress)?;

let mut info: ServerInfo = Default::default();
let mut connection = None;
if !options.retry_on_initial_connect {
if !retry_on_initial_connect {
debug!("retry on initial connect failure is disabled");
let (info_ok, connection_ok) = connector.try_connect().await?;
let (info_ok, connection_ok) = handler.try_connect().await?;
connection = Some(connection_ok);
info = info_ok;
}

let max_payload = Arc::clone(&handler.max_payload);
let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
let (sender, mut receiver) = mpsc::channel(sender_capacity);

let events_handler_task = task::spawn(async move {
while let Some(event) = events_rx.recv().await {
tracing::info!("event: {}", event);
if let Some(event_callback) = &options.event_callback {
event_callback.call(event).await;
}
}
});
let events_handler_task =
task::spawn(async move { handle_events(&mut events_rx, event_callback.as_ref()).await });

let connection_handler_task = task::spawn(async move {
if connection.is_none() && options.retry_on_initial_connect {
let (info, connection_ok) = match connector.connect().await {
if connection.is_none() && retry_on_initial_connect {
let (info, connection_ok) = match handler.connect().await {
Ok((info, connection)) => (info, connection),
Err(err) => {
error!("connection closed: {}", err);
Expand All @@ -1024,24 +999,36 @@ pub async fn connect_with_options<A: ToServerAddrs>(
}
let connection = connection.unwrap();
let mut connection_handler =
ConnectionHandler::new(connection, connector, info_sender, ping_period);
ConnectionHandler::new(connection, handler, info_sender, ping_interval);
connection_handler.process(&mut receiver).await
});

Ok(client::Builder {
info: info_watcher,
state: state_rx,
sender,
capacity: options.subscription_capacity,
inbox_prefix: options.inbox_prefix,
request_timeout: options.request_timeout,
capacity: subscription_capacity,
inbox_prefix,
request_timeout,
max_payload,
events_handler_task,
connection_handler_task,
}
.build())
}

async fn handle_events(
events_rx: &mut mpsc::Receiver<Event>,
event_callback: Option<&CallbackArg1<Event, ()>>,
) {
while let Some(event) = events_rx.recv().await {
tracing::info!("event: {}", event);
if let Some(event_callback) = &event_callback {
event_callback.call(event).await;
}
}
}

#[derive(Debug)]
pub enum Event {
Connected,
Expand Down

0 comments on commit d5c7a5e

Please sign in to comment.