diff --git a/src/client/builder.rs b/src/client/builder.rs deleted file mode 100644 index 6221d74..0000000 --- a/src/client/builder.rs +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2020 Ben Ashford - * - * Licensed under the Apache License, Version 2.0 or the MIT license - * , at your - * option. This file may not be copied, modified, or distributed - * except according to those terms. - */ - -use std::net::{SocketAddr, ToSocketAddrs}; -use std::sync::Arc; - -use crate::error; - -#[derive(Debug)] -/// Connection builder -pub struct ConnectionBuilder { - pub(crate) addr: SocketAddr, - pub(crate) username: Option>, - pub(crate) password: Option>, -} - -impl ConnectionBuilder { - pub fn new(addr: A) -> Result { - Ok(Self { - addr: addr - .to_socket_addrs()? - .next() - .ok_or(error::Error::Connection( - error::ConnectionReason::ConnectionFailed, - ))?, - username: None, - password: None, - }) - } - - /// Set the username used when connecting - pub fn password>>(&mut self, password: V) -> &mut Self { - self.password = Some(password.into()); - self - } - - /// Set the password used when connecting - pub fn username>>(&mut self, username: V) -> &mut Self { - self.username = Some(username.into()); - self - } -} diff --git a/src/client/builder/mod.rs b/src/client/builder/mod.rs new file mode 100644 index 0000000..f4720e1 --- /dev/null +++ b/src/client/builder/mod.rs @@ -0,0 +1,13 @@ +use std::{future::Future, pin::Pin}; + +use crate::client::connect::RespConnection; +use crate::error::Error; + +pub mod sentinel; +pub mod redis; + +/// Creates primitive connection to redis. This connection can be later upgraded +/// to support request-response messaging (PairedConnection) or pub-sub (PubsubConnection). +pub trait ConnectionBuilder: Send + Sync + 'static { + fn connect<'a>(&'a mut self) -> Pin> + Send + 'a>>; +} diff --git a/src/client/builder/redis.rs b/src/client/builder/redis.rs new file mode 100644 index 0000000..b646ae8 --- /dev/null +++ b/src/client/builder/redis.rs @@ -0,0 +1,60 @@ +use std::{future::Future, net::ToSocketAddrs, pin::Pin}; + +use super::ConnectionBuilder; +use crate::client::connect::{self, RespConnection}; +use crate::error::Error; + +/// Builder connecting directly to Redis. +pub struct RedisConnectionBuilder { + address: String, + username: Option, + password: Option, +} + +impl RedisConnectionBuilder { + pub fn new(address: String) -> Self { + Self { + address, + username: None, + password: None, + } + } + + pub fn username(mut self, username: Option) -> Self { + self.username = username; + self + } + + pub fn password(mut self, password: Option) -> Self { + self.password = password; + self + } +} + +impl ConnectionBuilder for RedisConnectionBuilder { + fn connect<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let addresses = self + .address + .to_socket_addrs() + .map_err(|e| Error::Unexpected("Couldn't resolve redis address".into()))?; + + for address in addresses { + let conn = connect::connect_with_auth( + &address, + self.username.as_deref(), + self.password.as_deref(), + ) + .await; + + if let Ok(conn) = conn { + return Ok(conn); + } + } + + return Err(Error::Unexpected("Couldn't connect to redis".into())); + }) + } +} diff --git a/src/client/builder/sentinel.rs b/src/client/builder/sentinel.rs new file mode 100644 index 0000000..c2f69b7 --- /dev/null +++ b/src/client/builder/sentinel.rs @@ -0,0 +1,264 @@ +use std::net::SocketAddr; +use std::{future::Future, net::ToSocketAddrs, pin::Pin}; +use futures_util::sink::SinkExt; +use futures_util::stream::StreamExt; + +use connect::connect_with_auth; + +use super::ConnectionBuilder; +use crate::client::connect::{self, RespConnection}; +use crate::error::Error; +use crate::resp::RespValue; + +const SENTINEL_CONNECTION_TIMEOUT: u64 = 500; + +/// Builder connecting to Redis through Sentinel. +pub struct SentinelConnectionBuilder { + sentinel_addresses: Vec, + sentinel_username: Option, + sentinel_password: Option, + redis_master_name: String, + redis_username: Option, + redis_password: Option, +} + +impl SentinelConnectionBuilder { + pub fn new(sentinel_addresses: Vec, redis_master_name: String) -> Self { + Self { + sentinel_addresses, + sentinel_username: None, + sentinel_password: None, + redis_master_name, + redis_username: None, + redis_password: None, + } + } + + pub fn sentinel_username(mut self, username: Option) -> Self { + self.sentinel_username = username; + self + } + + pub fn sentinel_password(mut self, password: Option) -> Self { + self.sentinel_password = password; + self + } + + pub fn redis_username(mut self, username: Option) -> Self { + self.redis_username = username; + self + } + + pub fn redis_password(mut self, password: Option) -> Self { + self.redis_password = password; + self + } +} + +impl ConnectionBuilder for SentinelConnectionBuilder { + fn connect<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let mut furthest_error = DiscoveryError::SentinelAddressResolvingFailure; + + for i in 0..self.sentinel_addresses.len() { + let socket_addresses = match self.sentinel_addresses[i].to_socket_addrs() { + Ok(addresses) => addresses, + Err(e) => { + furthest_error = std::cmp::max(furthest_error, DiscoveryError::SentinelAddressResolvingFailure); + continue; + } + }; + + 'socket_address_loop: for address in socket_addresses { + match discover_redis_master(address, &self).await { + Ok(conn) => { + self.sentinel_addresses[0..=i].rotate_right(1); + return Ok(conn); + } + Err(error) => { + furthest_error = std::cmp::max(furthest_error, error); + + if error > DiscoveryError::SentinelDoesNotKnowMasterAddress { + self.sentinel_addresses[0..=i].rotate_right(1); + break 'socket_address_loop; + } + } + }; + } + } + + // Failed to connect to redis master through Sentinels. + tokio::time::sleep(std::time::Duration::from_millis( + SENTINEL_CONNECTION_TIMEOUT, + )) + .await; + Err(Error::Unexpected(format!( + "Redis discovery failed at {:?}", + furthest_error + ))) + }) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum DiscoveryError { + SentinelAddressResolvingFailure, + SentinelConnectionFailure, + SentinelCommunicationFailure, + SentinelDoesNotKnowMasterAddress, + RedisConnectionFailure, + RedisCommunicationFailure, + RedisIsNotMaster, +} + +async fn discover_redis_master( + sentinel_address: SocketAddr, + builder: &SentinelConnectionBuilder, +) -> Result { + let sentinel_timeout = tokio::time::timeout( + std::time::Duration::from_millis(SENTINEL_CONNECTION_TIMEOUT), + connect::connect_with_auth( + &sentinel_address, + builder.sentinel_username.as_deref(), + builder.sentinel_password.as_deref(), + ), + ); + + let mut sentinel_connection = match sentinel_timeout.await { + Ok(Ok(connection)) => connection, + // Connection failure or timeout, try next Sentinel + _ => { + return Err(DiscoveryError::SentinelConnectionFailure); + } + }; + + match sentinel_connection + .send(resp_array![ + "SENTINEL", + "get-master-addr-by-name", + &builder.redis_master_name + ]) + .await + { + Ok(_) => {} + // Could not send message, try next Sentinel + Err(_) => return Err(DiscoveryError::SentinelCommunicationFailure), + } + + let redis_master_address = match sentinel_connection.next().await { + Some(Ok(value)) => match master_address_from_resp_value(value) { + Ok(Some(address)) => address, + // Master's address not known, try next Sentinel + Ok(None) => return Err(DiscoveryError::SentinelDoesNotKnowMasterAddress), + // Bad response, try next Sentinel + Err(_) => return Err(DiscoveryError::SentinelCommunicationFailure), + }, + // Disconnected, or bad response, try next sentinel + _ => return Err(DiscoveryError::SentinelCommunicationFailure), + }; + + drop(sentinel_connection); + + let mut redis_connection = match connect_with_auth( + &redis_master_address, + builder.redis_username.as_deref(), + builder.redis_password.as_deref(), + ) + .await + { + Ok(connection) => connection, + // Redis unavailable, try from beginning + Err(_) => { + return Err(DiscoveryError::RedisConnectionFailure); + } + }; + + match redis_connection.send(resp_array!["ROLE"]).await { + Ok(_) => {} + // Could not send message, try next Sentinel + Err(_) => return Err(DiscoveryError::RedisCommunicationFailure), + } + + let role = match redis_connection.next().await { + Some(Ok(value)) => match role_from_resp_value(value) { + Ok(role) => role, + // bad response + Err(_) => return Err(DiscoveryError::RedisCommunicationFailure), + }, + // Disconnected or bad response + _ => return Err(DiscoveryError::RedisCommunicationFailure), + }; + + if role == "master" { + // Found master, return connection + // let (out_tx, out_rx) = mpsc::unbounded(); + // let paired_connection_inner = PairedConnectionInner::new(redis_connection, out_rx); + // tokio::spawn(paired_connection_inner); + // return Ok(out_tx); + todo!() + } else { + return Err(DiscoveryError::RedisIsNotMaster); + } +} + +/// Extracts master address from a response to SENTINEL get-master-addr-by-name +/// command. The function returns SocketAddress when address in known and +/// None otherwise. +fn master_address_from_resp_value(value: RespValue) -> Result, String> { + use std::net::IpAddr; + use std::str::FromStr; + + if let RespValue::Nil = value { + return Ok(None); + } + + let array = match value { + RespValue::Array(array) => array, + _ => return Err("Response is not an array".to_owned()), + }; + + let mut iter = array.into_iter(); + let (ip_raw, port_raw) = match (iter.next(), iter.next(), iter.next()) { + (Some(RespValue::BulkString(ip_raw)), Some(RespValue::BulkString(port_raw)), None) => { + (ip_raw, port_raw) + } + _ => return Err("Response array does not contain exactly two bulk strings".to_owned()), + }; + + let ip = match String::from_utf8(ip_raw).map(|s| IpAddr::from_str(&s)) { + Ok(Ok(ip)) => ip, + Ok(_) => return Err("Sentinel returned malformed IP".to_owned()), + _ => return Err("Sentinel returned non-utf-8 IP".to_owned()), + }; + + let port = match String::from_utf8(port_raw).map(|s| s.parse::()) { + Ok(Ok(ip)) => ip, + Ok(_) => return Err("Sentinel returned non-u16 port".to_owned()), + _ => return Err("Sentinel returned non-utf-8 port".to_owned()), + }; + + match (ip, port).to_socket_addrs() { + Ok(mut address_iterator) => Ok(Some(address_iterator.next().unwrap())), + _ => Err("Sentinel returned invalid Redis socket address".to_owned()), + } +} + +/// Extracts role string from a response to ROLE command. +fn role_from_resp_value(value: RespValue) -> Result { + let array = match value { + RespValue::Array(array) => array, + _ => return Err("Response is not an array".to_owned()), + }; + + let role_raw = match array.into_iter().next() { + Some(RespValue::BulkString(role_raw)) => role_raw, + _ => return Err("Response array does not start with a bulk string".to_owned()), + }; + + match String::from_utf8(role_raw) { + Ok(role) => Ok(role), + _ => return Err("Redis returned non-utf-8 role".to_owned()), + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 68fc214..69a73a6 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -23,12 +23,14 @@ pub mod connect; #[macro_use] pub mod paired; -mod builder; +pub mod builder; pub mod pubsub; +pub mod reconnect; pub use self::{ - builder::ConnectionBuilder, + builder::redis::RedisConnectionBuilder, + builder::sentinel::SentinelConnectionBuilder, connect::connect, - paired::{paired_connect, PairedConnection}, - pubsub::{pubsub_connect, PubsubConnection}, + paired::{paired_connect, paired_reconnecting, PairedConnection}, + pubsub::{pubsub_connect, pubsub_reconnecting, PubsubConnection}, }; diff --git a/src/client/paired.rs b/src/client/paired.rs index 623ac9e..98dbf4e 100644 --- a/src/client/paired.rs +++ b/src/client/paired.rs @@ -18,20 +18,17 @@ use std::task::{Context, Poll}; use futures_channel::{mpsc, oneshot}; use futures_sink::Sink; -use futures_util::{ - future::{self, TryFutureExt}, - stream::StreamExt, -}; +use futures_util::{future, stream::StreamExt}; -use super::{ - connect::{connect_with_auth, RespConnection}, - ConnectionBuilder, +use crate::client::{ + builder::{redis::RedisConnectionBuilder, ConnectionBuilder}, + connect::RespConnection, + reconnect::{ComplexConnection, Reconnecting}, }; use crate::{ - error, - reconnect::{reconnect, Reconnect}, - resp, + error::{self, Error}, + resp::{self, RespValue}, }; /// The state of sending messages to a Redis server @@ -71,18 +68,23 @@ struct PairedConnectionInner { /// The status of the underlying connection send_status: SendStatus, + + /// TODO add comment + error_sender: Option>, } impl PairedConnectionInner { fn new( con: RespConnection, out_rx: mpsc::UnboundedReceiver<(resp::RespValue, oneshot::Sender)>, + error_sender: tokio::sync::oneshot::Sender<()>, ) -> Self { PairedConnectionInner { connection: con, out_rx, waiting: VecDeque::new(), send_status: SendStatus::Ok, + error_sender: Some(error_sender), } } @@ -112,6 +114,10 @@ impl PairedConnectionInner { let message = match status { SendStatus::End => { self.send_status = SendStatus::End; + + let error_sender = std::mem::replace(&mut self.error_sender, None); + let _ = error_sender.map(|s| s.send(())); + return Ok(false); } SendStatus::Full(msg) => msg, @@ -143,7 +149,12 @@ impl PairedConnectionInner { } } match self.connection.poll_next_unpin(cx) { - Poll::Ready(None) => Err(error::unexpected("Connection to Redis closed unexpectedly")), + Poll::Ready(None) => { + let error_sender = std::mem::replace(&mut self.error_sender, None); + let _ = error_sender.map(|s| s.send(())); + + Err(error::unexpected("Connection to Redis closed unexpectedly")) + } Poll::Ready(Some(msg)) => { let tx = match self.waiting.pop_front() { Some(tx) => tx, @@ -196,43 +207,38 @@ impl Future for PairedConnectionInner { /// A shareable and cheaply cloneable connection to which Redis commands can be sent #[derive(Debug, Clone)] pub struct PairedConnection { - out_tx_c: Arc>>, + out_tx_c: Arc>, } -async fn inner_conn_fn( - addr: SocketAddr, - username: Option>, - password: Option>, -) -> Result, error::Error> { - let username = username.as_ref().map(|u| u.as_ref()); - let password = password.as_ref().map(|p| p.as_ref()); - let connection = connect_with_auth(&addr, username, password).await?; - let (out_tx, out_rx) = mpsc::unbounded(); - let paired_connection_inner = PairedConnectionInner::new(connection, out_rx); - tokio::spawn(paired_connection_inner); - Ok(out_tx) +pub trait PairedConnectionBuilder: ConnectionBuilder + Sized { + fn paired_connect<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>>; + fn paired_reconnecting( + self, + ) -> Pin< + Box, error::Error>> + Send>, + >; } -impl ConnectionBuilder { - pub fn paired_connect(&self) -> impl Future> { - let addr = self.addr; - let username = self.username.clone(); - let password = self.password.clone(); - - let work_fn = |con: &mpsc::UnboundedSender, act| { - con.unbounded_send(act).map_err(|e| e.into()) - }; - - let conn_fn = move || { - let con_f = inner_conn_fn(addr, username.clone(), password.clone()); - Box::pin(con_f) as Pin> + Send + Sync>> - }; - - let reconnecting_con = reconnect(work_fn, conn_fn); - reconnecting_con.map_ok(|con| PairedConnection { - out_tx_c: Arc::new(con), +impl PairedConnectionBuilder for B { + fn paired_connect<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let primitive = self.connect().await?; + let (tx, _rx) = tokio::sync::oneshot::channel(); + Ok(PairedConnection::from_primitive(primitive, tx)) }) } + + fn paired_reconnecting( + self, + ) -> Pin< + Box, error::Error>> + Send>, + > { + Box::pin(async move { Reconnecting::::start(self).await }) + } } /// The default starting point to use most default Redis functionality. @@ -240,16 +246,25 @@ impl ConnectionBuilder { /// Returns a future that resolves to a `PairedConnection`. The future will complete when the /// initial connection is established. /// -/// Once the initial connection is established, the connection will attempt to reconnect should -/// the connection be broken (e.g. the Redis server being restarted), but reconnections occur -/// asynchronously, so all commands issued while the connection is unavailable will error, it is -/// the client's responsibility to retry commands as applicable. Also, at least one command needs -/// to be tried against the connection to trigger the re-connection attempt; this means at least -/// one command will definitely fail in a disconnect/reconnect scenario. +/// This connection does **not** reconnect automatically. See Reconnect if you're +/// interested in reconnecting. pub async fn paired_connect(addr: SocketAddr) -> Result { - ConnectionBuilder::new(addr)?.paired_connect().await + RedisConnectionBuilder::new(addr.to_string()) + .paired_connect() + .await +} + +// TODO add doc comment +pub async fn paired_reconnecting( + addr: SocketAddr, +) -> Result, error::Error> { + RedisConnectionBuilder::new(addr.to_string()) + .paired_reconnecting() + .await } +// TODO add paired_connect_reconnecting? + impl PairedConnection { /// Sends a command to Redis. /// @@ -277,7 +292,7 @@ impl PairedConnection { } let (tx, rx) = oneshot::channel(); - match self.out_tx_c.do_work((msg, tx)) { + match self.out_tx_c.unbounded_send((msg, tx)) { Ok(()) => future::Either::Left(async move { match rx.await { Ok(v) => Ok(T::from_resp(v)?), @@ -286,7 +301,7 @@ impl PairedConnection { )), } }), - Err(e) => future::Either::Right(future::ready(Err(e))), + Err(e) => future::Either::Right(future::ready(Err(e.into()))), } } @@ -301,6 +316,39 @@ impl PairedConnection { } } +impl ComplexConnection for PairedConnection { + /// Creates paired connection by wrapping a primitive connection with pairing logic. + fn from_primitive( + primitive: RespConnection, + error_sender: tokio::sync::oneshot::Sender<()>, + ) -> Self { + let (out_tx, out_rx) = mpsc::unbounded(); + let paired_connection_inner = PairedConnectionInner::new(primitive, out_rx, error_sender); + tokio::spawn(paired_connection_inner); + + Self { + out_tx_c: Arc::new(out_tx), + } + } +} + +impl Reconnecting { + /// Sends message using the currently active connection. + /// This is a shorthand for `.current().await?.send`. + pub async fn send(&self, msg: RespValue) -> Result { + let connection = self.current().await?; + connection.send(msg).await + } + + /// Sends message using the currently active connection. + /// This is a shorthand for `.current().await?.send_and_forget`. + pub async fn send_and_forget(&self, msg: RespValue) { + if let Ok(connection) = self.current().await { + connection.send_and_forget(msg); + } + } +} + #[cfg(test)] mod test { use super::ConnectionBuilder; diff --git a/src/client/pubsub.rs b/src/client/pubsub.rs index 4a91574..f9f37b2 100644 --- a/src/client/pubsub.rs +++ b/src/client/pubsub.rs @@ -15,21 +15,21 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use futures_channel::{mpsc, oneshot}; -use futures_sink::Sink; -use futures_util::{ - future::TryFutureExt, - stream::{Fuse, Stream, StreamExt}, +use futures_channel::{ + mpsc::{self, TrySendError}, + oneshot, }; +use futures_sink::Sink; +use futures_util::stream::{Fuse, Stream, StreamExt}; -use super::{ - connect::{connect_with_auth, RespConnection}, - ConnectionBuilder, +use crate::client::{ + builder::{redis::RedisConnectionBuilder, ConnectionBuilder}, + connect::RespConnection, + reconnect::{ComplexConnection, Reconnecting}, }; - use crate::{ error::{self, ConnectionReason}, - reconnect::{reconnect, Reconnect}, + // reconnect::{reconnect, Reconnect}, resp::{self, FromResp}, }; @@ -63,10 +63,16 @@ struct PubsubConnectionInner { pending_psubs: BTreeMap)>, /// Any incomplete messages to be sent... send_pending: Option, + /// TODO add comment + error_sender: Option>, } impl PubsubConnectionInner { - fn new(con: RespConnection, out_rx: mpsc::UnboundedReceiver) -> Self { + fn new( + con: RespConnection, + out_rx: mpsc::UnboundedReceiver, + error_sender: tokio::sync::oneshot::Sender<()>, + ) -> Self { PubsubConnectionInner { connection: con, out_rx: out_rx.fuse(), @@ -75,6 +81,7 @@ impl PubsubConnectionInner { pending_subs: BTreeMap::new(), pending_psubs: BTreeMap::new(), send_pending: None, + error_sender: Some(error_sender), } } @@ -338,57 +345,58 @@ impl Future for PubsubConnectionInner { /// A shareable reference to subscribe to PUBSUB topics #[derive(Debug, Clone)] pub struct PubsubConnection { - out_tx_c: Arc>>, + out_tx_c: Arc>, } -async fn inner_conn_fn( - addr: SocketAddr, - username: Option>, - password: Option>, -) -> Result, error::Error> { - let username = username.as_ref().map(|u| u.as_ref()); - let password = password.as_ref().map(|p| p.as_ref()); - - let connection = connect_with_auth(&addr, username, password).await?; - let (out_tx, out_rx) = mpsc::unbounded(); - tokio::spawn(async { - match PubsubConnectionInner::new(connection, out_rx).await { - Ok(_) => (), - Err(e) => log::error!("Pub/Sub error: {:?}", e), - } - }); - Ok(out_tx) +pub trait PubSubConnectionBuilder: ConnectionBuilder + Sized { + fn pubsub_connect<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>>; + fn pubsub_reconnecting( + self, + ) -> Pin< + Box, error::Error>> + Send>, + >; } -impl ConnectionBuilder { - pub fn pubsub_connect(&self) -> impl Future> { - let addr = self.addr; - let username = self.username.clone(); - let password = self.password.clone(); - - let reconnecting_f = reconnect( - |con: &mpsc::UnboundedSender, act| { - con.unbounded_send(act).map_err(|e| e.into()) - }, - move || { - let con_f = inner_conn_fn(addr, username.clone(), password.clone()); - Box::pin(con_f) - }, - ); - reconnecting_f.map_ok(|con| PubsubConnection { - out_tx_c: Arc::new(con), +impl PubSubConnectionBuilder for B { + fn pubsub_connect<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let primitive = self.connect().await?; + let (tx, _rx) = tokio::sync::oneshot::channel(); + Ok(PubsubConnection::from_primitive(primitive, tx)) }) } + fn pubsub_reconnecting( + self, + ) -> Pin< + Box, error::Error>> + Send>, + > { + Box::pin(async { Reconnecting::start(self).await }) + } } /// Used for Redis's PUBSUB functionality. /// /// Returns a future that resolves to a `PubsubConnection`. The future will only resolve once the -/// connection is established; after the intial establishment, if the connection drops for any -/// reason (e.g. Redis server being restarted), the connection will attempt re-connect, however -/// any subscriptions will need to be re-subscribed. +/// connection is established; after the initial establishment, if the connection drops for any +/// reason (e.g. Redis server being restarted), the connection will **not** attempt re-connect. +/// Check out pubsub_reconnecting if you're interested in auto-reconnect. pub async fn pubsub_connect(addr: SocketAddr) -> Result { - ConnectionBuilder::new(addr)?.pubsub_connect().await + RedisConnectionBuilder::new(addr.to_string()) + .pubsub_connect() + .await +} + +// TODO add doc comment +pub async fn pubsub_reconnecting( + addr: SocketAddr, +) -> Result, error::Error> { + RedisConnectionBuilder::new(addr.to_string()) + .pubsub_reconnecting() + .await } impl PubsubConnection { @@ -406,7 +414,7 @@ impl PubsubConnection { let (tx, rx) = mpsc::unbounded(); let (signal_t, signal_r) = oneshot::channel(); self.out_tx_c - .do_work(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t))?; + .unbounded_send(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t))?; match signal_r.await { Ok(_) => Ok(PubsubStream { @@ -422,7 +430,8 @@ impl PubsubConnection { let (tx, rx) = mpsc::unbounded(); let (signal_t, signal_r) = oneshot::channel(); self.out_tx_c - .do_work(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t))?; + .unbounded_send(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t)) + .map_err(send_error_to_redis_error)?; match signal_r.await { Ok(_) => Ok(PubsubStream { @@ -441,7 +450,7 @@ impl PubsubConnection { // anyway, and would be reported/logged elsewhere let _ = self .out_tx_c - .do_work(PubsubEvent::Unsubscribe(topic.into())); + .unbounded_send(PubsubEvent::Unsubscribe(topic.into())); } pub fn punsubscribe>(&self, topic: T) { @@ -449,7 +458,30 @@ impl PubsubConnection { // anyway, and would be reported/logged elsewhere let _ = self .out_tx_c - .do_work(PubsubEvent::Punsubscribe(topic.into())); + .unbounded_send(PubsubEvent::Punsubscribe(topic.into())); + } +} + +fn send_error_to_redis_error(e: TrySendError) -> error::Error { + e.into() +} + +impl ComplexConnection for PubsubConnection { + fn from_primitive( + primitive: RespConnection, + error_sender: tokio::sync::oneshot::Sender<()>, + ) -> Self { + let (out_tx, out_rx) = mpsc::unbounded(); + tokio::spawn(async { + // TODO make sure this error triggers a reconnect + match PubsubConnectionInner::new(primitive, out_rx, error_sender).await { + Ok(_) => (), + Err(e) => log::error!("Pub/Sub error: {:?}", e), + } + }); + Self { + out_tx_c: Arc::new(out_tx), + } } } diff --git a/src/client/reconnect.rs b/src/client/reconnect.rs new file mode 100644 index 0000000..3ebab7b --- /dev/null +++ b/src/client/reconnect.rs @@ -0,0 +1,121 @@ +use std::{future::Future, pin::Pin, sync::Arc}; +use tokio::sync::{Mutex, RwLock}; + +use super::{builder::ConnectionBuilder, connect::RespConnection}; +use crate::error::Error; + +/// Connection that can be constructed from RespConnection. +/// It must be cheaply cloneable. +pub trait ComplexConnection: Clone { + fn from_primitive( + primitive: RespConnection, + error_sender: tokio::sync::oneshot::Sender<()>, + ) -> Self; +} + +enum ReconnectionState { + NotConnected, + Connecting, + Connected(T), + ConnectionFailed(Error), +} + +struct ReconnectingInner { + // Connection builder for generating new connections. + builder: Mutex, + // Connection state. + state: RwLock>, +} + +impl ReconnectingInner +where + B: ConnectionBuilder + Send + Sync + 'static, + C: ComplexConnection + Send + Sync + 'static, +{ + /// Creates a new connection using the builder and replaces the old one. + /// It must return a boxed future, as it makes recursive calls. + fn reconnect(self: Arc) -> Pin> + Send>> { + Box::pin(async move { + log::info!("Reconnecting"); + let state = self.state.read().await; + if let ReconnectionState::Connecting = *state { + // if already connecting, ignore + // TODO should we return ok here? + return Ok(()); + } + drop(state); + + let mut builder = self.builder.lock().await; + let attempts: u32 = 100; + let mut attempt: u32 = 1; + loop { + if attempt > attempts { + let mut state = self.state.write().await; + *state = ReconnectionState::ConnectionFailed(Error::Unexpected(format!( + "Failed after {} connection attempts", + attempts + ))); + return Err(Error::Unexpected("all 10 connection attempts failed".into())); + } + if let Ok(connection) = builder.connect().await { + let (tx, rx) = tokio::sync::oneshot::channel(); + let inner = Arc::clone(&self); + tokio::spawn(async move { + if let Ok(_) = rx.await { + tokio::spawn(inner.reconnect()); + } + }); + let mut state = self.state.write().await; + *state = ReconnectionState::Connected(C::from_primitive(connection, tx)); + return Ok(()); + } + log::warn!("Reconnect attempt {} failed", attempt); + tokio::time::sleep(std::time::Duration::from_millis(1000)).await; + attempt += 1; + } + }) + } +} + +/// Wraps ComplexConnection and provides automatic reconnection. +/// Uses ConnectionBuilder to generate new connections. +#[derive(Clone)] +pub struct Reconnecting { + inner: Arc>, +} + +impl Reconnecting +where + B: ConnectionBuilder + Send + Sync + 'static, + C: ComplexConnection + Send + Sync + 'static, +{ + /// Constructs Reconnecting client and immediately connects to Redis. + /// TODO we want Error when the first connect fails, right? + pub async fn start(builder: B) -> Result { + let inner = ReconnectingInner { + builder: Mutex::new(builder), + state: RwLock::new(ReconnectionState::NotConnected), + }; + let reconnecting = Self { + inner: Arc::new(inner), + }; + + Arc::clone(&reconnecting.inner).reconnect().await?; + + Ok(reconnecting) + } + + /// Returns the active connection or an error if not connected. + /// Because the returned connection does **not** reconnect on error, + /// you should not keep it for too long. Call `current` before each + /// send to make sure you use a working connection. + pub async fn current(&self) -> Result { + match &*self.inner.state.read().await { + ReconnectionState::Connected(connection) => Ok(connection.clone()), + _ => Err(Error::Unexpected("Connecting/reconnecting".into())), + } + } + + // TODO should we expose reconnect method to allow users to force reconnect? + // It could be useful, e.g. for when all auto-reconnect attempts fail +} diff --git a/src/lib.rs b/src/lib.rs index bcae9a5..7b9745f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,5 +74,3 @@ pub mod resp; pub mod client; pub mod error; - -pub(crate) mod reconnect; diff --git a/src/reconnect.rs b/src/reconnect.rs deleted file mode 100644 index e61912d..0000000 --- a/src/reconnect.rs +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright 2018-2020 Ben Ashford - * - * Licensed under the Apache License, Version 2.0 or the MIT license - * , at your - * option. This file may not be copied, modified, or distributed - * except according to those terms. - */ - -use std::fmt; -use std::future::Future; -use std::mem; -use std::pin::Pin; -use std::sync::{Arc, Mutex, MutexGuard}; -use std::time::Duration; - -use futures_util::{ - future::{self, Either}, - TryFutureExt, -}; - -use tokio::time::timeout; - -use crate::error::{self, ConnectionReason}; - -type WorkFn = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync; -type ConnFn = - dyn Fn() -> Pin> + Send + Sync>> + Send + Sync; - -struct ReconnectInner { - state: Mutex>, - work_fn: Box>, - conn_fn: Box>, -} - -pub(crate) struct Reconnect(Arc>); - -impl Clone for Reconnect { - fn clone(&self) -> Self { - Reconnect(self.0.clone()) - } -} - -impl fmt::Debug for Reconnect -where - T: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Reconnect") - .field("state", &self.0.state) - .field("work_fn", &String::from("REDACTED")) - .field("conn_fn", &String::from("REDACTED")) - .finish() - } -} - -pub(crate) async fn reconnect(w: W, c: C) -> Result, error::Error> -where - A: Send + 'static, - W: Fn(&T, A) -> Result<(), error::Error> + Send + Sync + 'static, - C: Fn() -> Pin> + Send + Sync>> - + Send - + Sync - + 'static, - T: Clone + Send + Sync + 'static, -{ - let r = Reconnect(Arc::new(ReconnectInner { - state: Mutex::new(ReconnectState::NotConnected), - - work_fn: Box::new(w), - conn_fn: Box::new(c), - })); - let rf = { - let state = r.0.state.lock().expect("Poisoned lock"); - r.reconnect(state) - }; - rf.await?; - Ok(r) -} - -enum ReconnectState { - NotConnected, - Connected(T), - ConnectionFailed(Mutex>), - Connecting, -} - -impl fmt::Debug for ReconnectState { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ReconnectState::")?; - match self { - NotConnected => write!(f, "NotConnected"), - Connected(_) => write!(f, "Connected"), - ConnectionFailed(_) => write!(f, "ConnectionFailed"), - Connecting => write!(f, "Connecting"), - } - } -} - -use self::ReconnectState::*; - -const CONNECTION_TIMEOUT_SECONDS: u64 = 10; -const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS); - -impl Reconnect -where - A: Send + 'static, - T: Clone + Send + Sync + 'static, -{ - fn call_work(&self, t: &T, a: A) -> Result { - if let Err(e) = (self.0.work_fn)(t, a) { - match e { - error::Error::IO(_) | error::Error::Unexpected(_) => { - log::error!("Error in work_fn will force connection closed, next command will attempt to re-establish connection: {}", e); - return Ok(false); - } - _ => (), - } - Err(e) - } else { - Ok(true) - } - } - - pub(crate) fn do_work(&self, a: A) -> Result<(), error::Error> { - let mut state = self.0.state.lock().expect("Cannot obtain read lock"); - match *state { - NotConnected => { - self.reconnect_spawn(state); - Err(error::Error::Connection(ConnectionReason::NotConnected)) - } - Connected(ref t) => { - let success = self.call_work(t, a)?; - if !success { - *state = NotConnected; - self.reconnect_spawn(state); - } - Ok(()) - } - ConnectionFailed(ref e) => { - let mut lock = e.lock().expect("Poisioned lock"); - let e = match lock.take() { - Some(e) => e, - None => error::Error::Connection(ConnectionReason::NotConnected), - }; - mem::drop(lock); - - *state = NotConnected; - self.reconnect_spawn(state); - Err(e) - } - Connecting => Err(error::Error::Connection(ConnectionReason::Connecting)), - } - } - - /// Returns a future that completes when the connection is established or failed to establish - /// used only for timing. - fn reconnect( - &self, - mut state: MutexGuard>, - ) -> impl Future> + Send { - log::info!("Attempting to reconnect, current state: {:?}", *state); - - match *state { - Connected(_) => { - return Either::Right(future::err(error::Error::Connection( - ConnectionReason::Connected, - ))); - } - Connecting => { - return Either::Right(future::err(error::Error::Connection( - ConnectionReason::Connecting, - ))); - } - NotConnected | ConnectionFailed(_) => (), - } - *state = ReconnectState::Connecting; - - mem::drop(state); - - let reconnect = self.clone(); - - let connection_f = async move { - let connection = match timeout(CONNECTION_TIMEOUT, (reconnect.0.conn_fn)()).await { - Ok(con_r) => con_r, - Err(_) => Err(error::internal(format!( - "Connection timed-out after {} seconds", - CONNECTION_TIMEOUT_SECONDS - ))), - }; - - let mut state = reconnect.0.state.lock().expect("Cannot obtain write lock"); - - match *state { - NotConnected | Connecting => match connection { - Ok(t) => { - log::info!("Connection established"); - *state = Connected(t); - Ok(()) - } - Err(e) => { - log::error!("Connection cannot be established: {}", e); - *state = ConnectionFailed(Mutex::new(Some(e))); - Err(error::Error::Connection(ConnectionReason::ConnectionFailed)) - } - }, - ConnectionFailed(_) => { - panic!("The connection state wasn't reset before connecting") - } - Connected(_) => panic!("A connected state shouldn't be attempting to reconnect"), - } - }; - - Either::Left(connection_f) - } - - fn reconnect_spawn(&self, state: MutexGuard>) { - let reconnect_f = self - .reconnect(state) - .map_err(|e| log::error!("Error asynchronously reconnecting: {}", e)); - - tokio::spawn(reconnect_f); - } -}