From 7f087c5384e3d3f5aa8a661dca3953c81fb94286 Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Tue, 19 Mar 2024 09:50:43 +0100 Subject: [PATCH 01/12] fix(rumqttc): do not store the request tx handle in the event loop Fixes #815 --- rumqttc/src/v5/client.rs | 6 +++--- rumqttc/src/v5/eventloop.rs | 16 ++++++---------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c5..7d8a36bc8 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -11,7 +11,7 @@ use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use crate::valid_topic; use bytes::Bytes; -use flume::{SendError, Sender, TrySendError}; +use flume::{bounded, SendError, Sender, TrySendError}; use futures_util::FutureExt; use tokio::runtime::{self, Runtime}; use tokio::time::timeout; @@ -54,8 +54,8 @@ impl AsyncClient { /// /// `cap` specifies the capacity of the bounded async channel. pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { - let eventloop = EventLoop::new(options, cap); - let request_tx = eventloop.requests_tx.clone(); + let (request_tx, request_rx) = bounded(cap); + let eventloop = EventLoop::new(options, request_rx); let client = AsyncClient { request_tx }; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index ab1edb17c..f8a582f3f 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -4,7 +4,7 @@ use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Tra use crate::eventloop::socket_connect; use crate::framed::AsyncReadWrite; -use flume::{bounded, Receiver, Sender}; +use flume::Receiver; use tokio::select; use tokio::time::{self, error::Elapsed, Instant, Sleep}; @@ -73,9 +73,7 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver, - /// Requests handle to send requests - pub(crate) requests_tx: Sender, + request_rx: Receiver, /// Pending packets from last session pub pending: VecDeque, /// Network connection to the broker @@ -96,8 +94,7 @@ impl EventLoop { /// /// When connection encounters critical errors (like auth failure), user has a choice to /// access and update `options`, `state` and `requests`. - pub fn new(options: MqttOptions, cap: usize) -> EventLoop { - let (requests_tx, requests_rx) = bounded(cap); + pub fn new(options: MqttOptions, request_rx: Receiver) -> EventLoop { let pending = VecDeque::new(); let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX); let manual_acks = options.manual_acks; @@ -105,8 +102,7 @@ impl EventLoop { EventLoop { options, state: MqttState::new(inflight_limit, manual_acks), - requests_tx, - requests_rx, + request_rx, pending, network: None, keepalive_timeout: None, @@ -126,7 +122,7 @@ impl EventLoop { self.pending.extend(self.state.clean()); // drain requests from channel which weren't yet received - let requests_in_channel = self.requests_rx.drain(); + let requests_in_channel = self.request_rx.drain(); self.pending.extend(requests_in_channel); } @@ -205,7 +201,7 @@ impl EventLoop { // outgoing requests (along with 1b). o = Self::next_request( &mut self.pending, - &self.requests_rx, + &self.request_rx, self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { From a9767c4de196eb0e5ce7ec29296bef8649046441 Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Tue, 19 Mar 2024 09:56:38 +0100 Subject: [PATCH 02/12] feat(rumqttc): declare rx batch size in event loop This number shall be reused for the tx side as well. --- rumqttc/src/v5/eventloop.rs | 5 ++++- rumqttc/src/v5/framed.rs | 15 +++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index f8a582f3f..4c8002b75 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -31,6 +31,9 @@ use { #[cfg(feature = "proxy")] use crate::proxy::ProxyError; +/// Number of packets or requests processed before flusing the network +const BATCH_SIZE: usize = 10; + /// Critical errors during eventloop polling #[derive(Debug, thiserror::Error)] pub enum ConnectionError { @@ -212,7 +215,7 @@ impl EventLoop { Err(_) => Err(ConnectionError::RequestsDone), }, // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { + o = network.readb(&mut self.state, BATCH_SIZE) => { o?; // flush all the acks and return first incoming packet network.flush().await?; diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index c7e06a250..0fcec0004 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -14,8 +14,6 @@ use super::{Incoming, StateError}; pub struct Network { /// Frame MQTT packets from network connection framed: Framed, Codec>, - /// Maximum readv count - max_readb_count: usize, } impl Network { pub fn new(socket: impl AsyncReadWrite + 'static, max_incoming_size: Option) -> Network { @@ -26,10 +24,7 @@ impl Network { }; let framed = Framed::new(socket, codec); - Network { - framed, - max_readb_count: 10, - } + Network { framed } } pub fn set_max_outgoing_size(&mut self, max_outgoing_size: Option) { @@ -48,7 +43,11 @@ impl Network { /// Read packets in bulk. This allow replies to be in bulk. This method is used /// after the connection is established to read a bunch of incoming packets - pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { + pub async fn readb( + &mut self, + state: &mut MqttState, + batch_size: usize, + ) -> Result<(), StateError> { // wait for the first read let mut res = self.framed.next().await; let mut count = 1; @@ -60,7 +59,7 @@ impl Network { } count += 1; - if count >= self.max_readb_count { + if count >= batch_size { break; } } From 194bb5c1c68f9b8f15dd8ea978d236fb5dff30ce Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Tue, 19 Mar 2024 10:40:53 +0100 Subject: [PATCH 03/12] fix(rumqttc): fix possible starvation with pending requests Store the pending throttle interval within the EventLoop. Fixes: #814 --- rumqttc/src/v5/eventloop.rs | 88 +++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 29 deletions(-) diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 4c8002b75..3e8519693 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -5,12 +5,14 @@ use crate::eventloop::socket_connect; use crate::framed::AsyncReadWrite; use flume::Receiver; +use futures_util::{Stream, StreamExt}; use tokio::select; use tokio::time::{self, error::Elapsed, Instant, Sleep}; use std::collections::VecDeque; use std::io; use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; use super::mqttbytes::v5::ConnectReturnCode; @@ -77,8 +79,8 @@ pub struct EventLoop { pub state: MqttState, /// Request stream request_rx: Receiver, - /// Pending packets from last session - pub pending: VecDeque, + /// Pending requests from the last session + pending: PendingRequests, /// Network connection to the broker network: Option, /// Keep alive time @@ -98,9 +100,9 @@ impl EventLoop { /// When connection encounters critical errors (like auth failure), user has a choice to /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, request_rx: Receiver) -> EventLoop { - let pending = VecDeque::new(); let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX); let manual_acks = options.manual_acks; + let pending = PendingRequests::new(options.pending_throttle); EventLoop { options, @@ -161,8 +163,6 @@ impl EventLoop { /// Select on network and requests and generate keepalive pings when necessary async fn select(&mut self) -> Result { let network = self.network.as_mut().unwrap(); - // let await_acks = self.state.await_acks; - let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight; let collision = self.state.collision.is_some(); @@ -202,17 +202,16 @@ impl EventLoop { // After collision with pkid 1 -> [1b ,2, x, 4, 5]. // 1a is saved to state and event loop is set to collision mode stopping new // outgoing requests (along with 1b). - o = Self::next_request( - &mut self.pending, - &self.request_rx, - self.options.pending_throttle - ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok(request) => { - self.state.handle_outgoing_packet(request)?; - network.flush().await?; - Ok(self.state.events.pop_front().unwrap()) - } - Err(_) => Err(ConnectionError::RequestsDone), + Some(request) = self.pending.next(), if !inflight_full && !collision => { + self.state.handle_outgoing_packet(request)?; + network.flush().await?; + Ok(self.state.events.pop_front().unwrap()) + }, + request = self.request_rx.recv_async(), if self.pending.is_empty() && !inflight_full && !collision => { + let request = request.map_err(|_| ConnectionError::RequestsDone)?; + self.state.handle_outgoing_packet(request)?; + network.flush().await?; + Ok(self.state.events.pop_front().unwrap()) }, // Pull a bunch of packets from network, reply in bunch and yield the first item o = network.readb(&mut self.state, BATCH_SIZE) => { @@ -231,26 +230,57 @@ impl EventLoop { network.flush().await?; Ok(self.state.events.pop_front().unwrap()) } + else => unreachable!("Eventloop select is exhaustive"), } } +} - async fn next_request( - pending: &mut VecDeque, - rx: &Receiver, - pending_throttle: Duration, - ) -> Result { - if !pending.is_empty() { - time::sleep(pending_throttle).await; - // We must call .next() AFTER sleep() otherwise .next() would - // advance the iterator but the future might be canceled before return - Ok(pending.pop_front().unwrap()) +/// Pending requets yielded with a configured rate. If the queue is empty the stream will yield pending. +struct PendingRequests { + /// Interval + interval: Option, + /// Pending requests + requests: VecDeque, +} + +impl PendingRequests { + pub fn new(interval: Duration) -> Self { + let interval = (!interval.is_zero()).then(|| time::interval(interval)); + PendingRequests { + interval, + requests: VecDeque::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.requests.is_empty() + } + + pub fn extend(&mut self, requests: impl IntoIterator) { + self.requests.extend(requests); + } +} + +impl Stream for PendingRequests { + type Item = Request; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_empty() { + Poll::Pending } else { - match rx.recv_async().await { - Ok(r) => Ok(r), - Err(_) => Err(ConnectionError::RequestsDone), + match self.interval.as_mut() { + Some(interval) => match interval.poll_tick(cx) { + Poll::Ready(_) => Poll::Ready(self.requests.pop_front()), + Poll::Pending => Poll::Pending, + }, + None => Poll::Ready(self.requests.pop_front()), } } } + + fn size_hint(&self) -> (usize, Option) { + (self.requests.len(), Some(self.requests.len())) + } } /// This stream internally processes requests from the request stream provided to the eventloop From 520a957174a280a1ca4359d9047a3bd18eaad7cc Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Tue, 19 Mar 2024 12:32:57 +0100 Subject: [PATCH 04/12] feat(rumqttc): batch request processing before flushing the network Collect up to BATCH_SIZE request from the client channel before flushing the network. Fixes #810 --- rumqttc/CHANGELOG.md | 1 + rumqttc/src/v5/eventloop.rs | 34 ++++++++++++++++++++++++++++++++-- rumqttc/src/v5/state.rs | 10 ++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 1045cfcf1..f2e80ca6d 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Process multiple outgoing client requests before flushing the network buffer (reduces number of system calls) * `size()` method on `Packet` calculates size once serialized. * `read()` and `write()` methods on `Packet`. diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 3e8519693..eec82f3fd 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -163,8 +163,8 @@ impl EventLoop { /// Select on network and requests and generate keepalive pings when necessary async fn select(&mut self) -> Result { let network = self.network.as_mut().unwrap(); - let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight; - let collision = self.state.collision.is_some(); + let inflight_full = self.state.is_inflight_full(); + let collision = self.state.has_collision(); // Read buffered events from previous polls before calling a new poll if let Some(event) = self.state.events.pop_front() { @@ -208,8 +208,38 @@ impl EventLoop { Ok(self.state.events.pop_front().unwrap()) }, request = self.request_rx.recv_async(), if self.pending.is_empty() && !inflight_full && !collision => { + // Process first request let request = request.map_err(|_| ConnectionError::RequestsDone)?; self.state.handle_outgoing_packet(request)?; + + // Take up to BATCH_SIZE - 1 requests from the channel until + // - the channel is empty + // - the inflight queue is full + // - there is a collision + // If the channel is closed this is reported in the next iteration from the async recv above. + for _ in 0..(BATCH_SIZE - 1) { + if self.request_rx.is_empty() || self.state.is_inflight_full() || self.state.has_collision() + { + break; + } + + // Safe to call the blocking `recv` in here since we know the channel is not empty. + // Ensure a flush in case of any error. + if let Err(e) = self + .request_rx + .recv() + .map_err(|_| ConnectionError::RequestsDone) + .and_then(|request| { + self.state + .handle_outgoing_packet(request) + .map_err(Into::into) + }) + { + network.flush().await?; + return Err(e); + } + } + network.flush().await?; Ok(self.state.events.pop_front().unwrap()) }, diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 854aa7b0f..95201bde6 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -185,6 +185,16 @@ impl MqttState { self.inflight } + /// Returns true if the inflight limit is reached + pub fn is_inflight_full(&self) -> bool { + self.inflight >= self.max_outgoing_inflight + } + + /// Returns true if the state has a unresolved collision + pub fn has_collision(&self) -> bool { + self.collision.is_some() + } + /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop pub fn handle_outgoing_packet( From 6e3238d7eda1fce3ee86b352d9e7e98e3d0b52eb Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Tue, 26 Mar 2024 08:45:49 +0100 Subject: [PATCH 05/12] feat(rumqttc): prepare centralized batching --- rumqttc/src/v5/eventloop.rs | 36 ++++-------------------- rumqttc/src/v5/framed.rs | 55 +++++++------------------------------ 2 files changed, 15 insertions(+), 76 deletions(-) diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index eec82f3fd..ee3d88961 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -211,41 +211,15 @@ impl EventLoop { // Process first request let request = request.map_err(|_| ConnectionError::RequestsDone)?; self.state.handle_outgoing_packet(request)?; - - // Take up to BATCH_SIZE - 1 requests from the channel until - // - the channel is empty - // - the inflight queue is full - // - there is a collision - // If the channel is closed this is reported in the next iteration from the async recv above. - for _ in 0..(BATCH_SIZE - 1) { - if self.request_rx.is_empty() || self.state.is_inflight_full() || self.state.has_collision() - { - break; - } - - // Safe to call the blocking `recv` in here since we know the channel is not empty. - // Ensure a flush in case of any error. - if let Err(e) = self - .request_rx - .recv() - .map_err(|_| ConnectionError::RequestsDone) - .and_then(|request| { - self.state - .handle_outgoing_packet(request) - .map_err(Into::into) - }) - { - network.flush().await?; - return Err(e); - } - } - network.flush().await?; Ok(self.state.events.pop_front().unwrap()) }, // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state, BATCH_SIZE) => { - o?; + packet = network.read() => { + let packet = packet?; + if let Some(packet) = self.state.handle_incoming_packet(packet)? { + network.write(packet).await?; + } // flush all the acks and return first incoming packet network.flush().await?; Ok(self.state.events.pop_front().unwrap()) diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index 0fcec0004..eebfed566 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -1,11 +1,11 @@ -use futures_util::{FutureExt, SinkExt}; +use futures_util::SinkExt; use tokio_stream::StreamExt; use tokio_util::codec::Framed; use crate::framed::AsyncReadWrite; use super::mqttbytes::v5::Packet; -use super::{mqttbytes, Codec, Connect, MqttOptions, MqttState}; +use super::{mqttbytes, Codec, Connect, Login, MqttOptions}; use super::{Incoming, StateError}; /// Network transforms packets <-> frames efficiently. It takes @@ -41,42 +41,6 @@ impl Network { } } - /// Read packets in bulk. This allow replies to be in bulk. This method is used - /// after the connection is established to read a bunch of incoming packets - pub async fn readb( - &mut self, - state: &mut MqttState, - batch_size: usize, - ) -> Result<(), StateError> { - // wait for the first read - let mut res = self.framed.next().await; - let mut count = 1; - loop { - match res { - Some(Ok(packet)) => { - if let Some(outgoing) = state.handle_incoming_packet(packet)? { - self.write(outgoing).await?; - } - - count += 1; - if count >= batch_size { - break; - } - } - Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(), - Some(Err(e)) => return Err(StateError::Deserialization(e)), - None => return Err(StateError::ConnectionAborted), - } - // do not wait for subsequent reads - match self.framed.next().now_or_never() { - Some(r) => res = r, - _ => break, - }; - } - - Ok(()) - } - /// Serializes packet into write buffer pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { self.framed @@ -85,6 +49,14 @@ impl Network { .map_err(StateError::Deserialization) } + /// Flush the outgoing sink + pub async fn flush(&mut self) -> Result<(), StateError> { + self.framed + .flush() + .await + .map_err(StateError::Deserialization) + } + pub async fn connect( &mut self, connect: Connect, @@ -97,11 +69,4 @@ impl Network { self.flush().await } - - pub async fn flush(&mut self) -> Result<(), StateError> { - self.framed - .flush() - .await - .map_err(StateError::Deserialization) - } } From d4fe99b7ac204168a2ed0f7c74a1e197ed3bb674 Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Tue, 26 Mar 2024 16:03:47 +0100 Subject: [PATCH 06/12] fix(rumqttc): reserve buffer space before serializing a v5 packet --- rumqttc/src/v5/mqttbytes/v5/mod.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index 342278596..d2cab8a8e 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -129,8 +129,10 @@ impl Packet { } pub fn write(&self, write: &mut BytesMut, max_size: Option) -> Result { + let size = self.size(); + if let Some(max_size) = max_size { - if self.size() > max_size as usize { + if size > max_size as usize { return Err(Error::OutgoingPacketTooLarge { pkt_size: self.size() as u32, max: max_size, @@ -138,6 +140,9 @@ impl Packet { } } + // Ensure that `write` can take the serialized packet + write.reserve(size); + match self { Self::Publish(publish) => publish.write(write), Self::Subscribe(subscription) => subscription.write(write), From 32bc49e7cfb5aeefeee234577d2779c656b71cac Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Tue, 26 Mar 2024 16:04:19 +0100 Subject: [PATCH 07/12] feat(rumqttc): batch processing --- rumqttc/src/v5/eventloop.rs | 259 +++++++++++++++------------ rumqttc/src/v5/framed.rs | 14 +- rumqttc/src/v5/mod.rs | 14 +- rumqttc/src/v5/mqttbytes/v5/codec.rs | 63 ++++++- 4 files changed, 217 insertions(+), 133 deletions(-) diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index ee3d88961..e109a6bf4 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -33,9 +33,6 @@ use { #[cfg(feature = "proxy")] use crate::proxy::ProxyError; -/// Number of packets or requests processed before flusing the network -const BATCH_SIZE: usize = 10; - /// Critical errors during eventloop polling #[derive(Debug, thiserror::Error)] pub enum ConnectionError { @@ -56,6 +53,8 @@ pub enum ConnectionError { Io(#[from] io::Error), #[error("Connection refused, return code: `{0:?}`")] ConnectionRefused(ConnectReturnCode), + #[error("Connection closed")] + ConnectionClosed, #[error("Expected ConnAck packet, received: {0:?}")] NotConnAck(Box), #[error("Requests done")] @@ -77,10 +76,12 @@ pub struct EventLoop { pub options: MqttOptions, /// Current state of the connection pub state: MqttState, + /// Batch size + batch_size: usize, /// Request stream - request_rx: Receiver, + requests: Receiver, /// Pending requests from the last session - pending: PendingRequests, + pending: IntervalQueue, /// Network connection to the broker network: Option, /// Keep alive time @@ -96,18 +97,18 @@ pub enum Event { impl EventLoop { /// New MQTT `EventLoop` - /// - /// When connection encounters critical errors (like auth failure), user has a choice to - /// access and update `options`, `state` and `requests`. - pub fn new(options: MqttOptions, request_rx: Receiver) -> EventLoop { + pub(crate) fn new(options: MqttOptions, requests: Receiver) -> EventLoop { let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX); let manual_acks = options.manual_acks; - let pending = PendingRequests::new(options.pending_throttle); + let pending = IntervalQueue::new(options.pending_throttle); + let batch_size = options.max_batch_size; + let state = MqttState::new(inflight_limit, manual_acks); EventLoop { options, - state: MqttState::new(inflight_limit, manual_acks), - request_rx, + state, + batch_size, + requests, pending, network: None, keepalive_timeout: None, @@ -127,8 +128,7 @@ impl EventLoop { self.pending.extend(self.state.clean()); // drain requests from channel which weren't yet received - let requests_in_channel = self.request_rx.drain(); - self.pending.extend(requests_in_channel); + self.pending.extend(self.requests.drain()); } /// Yields Next notification or outgoing request and periodically pings @@ -148,142 +148,176 @@ impl EventLoop { self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); } + // A connack never produces a response packet. Safe to ignore the return value + // of `handle_incoming_packet` self.state.handle_incoming_packet(connack)?; + self.pending.reset(); } - match self.select().await { - Ok(v) => Ok(v), - Err(e) => { - self.clean(); - Err(e) + // Read buffered events from previous polls before calling a new poll + if let Some(event) = self.state.events.pop_front() { + Ok(event) + } else { + match self.poll_process().await { + Ok(v) => Ok(v), + Err(e) => { + self.clean(); + Err(e) + } } } } /// Select on network and requests and generate keepalive pings when necessary - async fn select(&mut self) -> Result { + async fn poll_process(&mut self) -> Result { let network = self.network.as_mut().unwrap(); - let inflight_full = self.state.is_inflight_full(); - let collision = self.state.has_collision(); - // Read buffered events from previous polls before calling a new poll - if let Some(event) = self.state.events.pop_front() { - return Ok(event); - } - - // this loop is necessary since self.incoming.pop_front() might return None. In that case, - // instead of returning a None event, we try again. - select! { - // Handles pending and new requests. - // If available, prioritises pending requests from previous session. - // Else, pulls next request from user requests channel. - // If conditions in the below branch are for flow control. - // The branch is disabled if there's no pending messages and new user requests - // cannot be serviced due flow control. - // We read next user user request only when inflight messages are < configured inflight - // and there are no collisions while handling previous outgoing requests. - // - // Flow control is based on ack count. If inflight packet count in the buffer is - // less than max_inflight setting, next outgoing request will progress. For this - // to work correctly, broker should ack in sequence (a lot of brokers won't) - // - // E.g If max inflight = 5, user requests will be blocked when inflight queue - // looks like this -> [1, 2, 3, 4, 5]. - // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. - // This pulls next user request. But because max packet id = max_inflight, next - // user request's packet id will roll to 1. This replaces existing packet id 1. - // Resulting in a collision - // - // Eventloop can stop receiving outgoing user requests when previous outgoing - // request collided. I.e collision state. Collision state will be cleared only - // when correct ack is received - // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. - // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. - // After collision with pkid 1 -> [1b ,2, x, 4, 5]. - // 1a is saved to state and event loop is set to collision mode stopping new - // outgoing requests (along with 1b). - Some(request) = self.pending.next(), if !inflight_full && !collision => { - self.state.handle_outgoing_packet(request)?; - network.flush().await?; - Ok(self.state.events.pop_front().unwrap()) - }, - request = self.request_rx.recv_async(), if self.pending.is_empty() && !inflight_full && !collision => { - // Process first request - let request = request.map_err(|_| ConnectionError::RequestsDone)?; - self.state.handle_outgoing_packet(request)?; - network.flush().await?; - Ok(self.state.events.pop_front().unwrap()) - }, - // Pull a bunch of packets from network, reply in bunch and yield the first item - packet = network.read() => { - let packet = packet?; - if let Some(packet) = self.state.handle_incoming_packet(packet)? { - network.write(packet).await?; + for _ in 0..self.batch_size { + let inflight_full = self.state.is_inflight_full(); + let collision = self.state.has_collision(); + + select! { + // Handles pending and new requests. + // If available, prioritises pending requests from previous session. + // Else, pulls next request from user requests channel. + // If conditions in the below branch are for flow control. + // The branch is disabled if there's no pending messages and new user requests + // cannot be serviced due flow control. + // We read next user user request only when inflight messages are < configured inflight + // and there are no collisions while handling previous outgoing requests. + // + // Flow control is based on ack count. If inflight packet count in the buffer is + // less than max_inflight setting, next outgoing request will progress. For this + // to work correctly, broker should ack in sequence (a lot of brokers won't) + // + // E.g If max inflight = 5, user requests will be blocked when inflight queue + // looks like this -> [1, 2, 3, 4, 5]. + // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. + // This pulls next user request. But because max packet id = max_inflight, next + // user request's packet id will roll to 1. This replaces existing packet id 1. + // Resulting in a collision + // + // Eventloop can stop receiving outgoing user requests when previous outgoing + // request collided. I.e collision state. Collision state will be cleared only + // when correct ack is received + // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. + // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. + // After collision with pkid 1 -> [1b ,2, x, 4, 5]. + // 1a is saved to state and event loop is set to collision mode stopping new + // outgoing requests (along with 1b). + Some(request) = self.pending.next(), if !inflight_full && !collision => { + if let Some(packet) = self.state.handle_outgoing_packet(request)? { + network.write(packet).await?; + } + }, + request = self.requests.recv_async(), if self.pending.is_empty() && !inflight_full && !collision => { + let request = request.map_err(|_| ConnectionError::RequestsDone)?; + if let Some(packet) = self.state.handle_outgoing_packet(request)? { + network.write(packet).await?; + } + }, + // Process next packet received from io + packet = network.read() => { + match packet? { + Some(packet) => if let Some(packet) = self.state.handle_incoming_packet(packet)? { + let flush = matches!(packet, Packet::PingResp(_)); + network.write(packet).await?; + if flush { + break; + } + } + None => return Err(ConnectionError::ConnectionClosed), + } + }, + // We generate pings irrespective of network activity. This keeps the ping logic + // simple. We can change this behavior in future if necessary (to prevent extra pings) + _ = self.keepalive_timeout.as_mut().unwrap() => { + let timeout = self.keepalive_timeout.as_mut().unwrap(); + timeout.as_mut().reset(Instant::now() + self.options.keep_alive); + if let Some(packet) = self.state.handle_outgoing_packet(Request::PingReq)? { + network.write(packet).await?; + } } - // flush all the acks and return first incoming packet - network.flush().await?; - Ok(self.state.events.pop_front().unwrap()) - }, - // We generate pings irrespective of network activity. This keeps the ping logic - // simple. We can change this behavior in future if necessary (to prevent extra pings) - _ = self.keepalive_timeout.as_mut().unwrap() => { - let timeout = self.keepalive_timeout.as_mut().unwrap(); - timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - - self.state.handle_outgoing_packet(Request::PingReq)?; - network.flush().await?; - Ok(self.state.events.pop_front().unwrap()) + else => unreachable!("Eventloop select is exhaustive"), + } + + // Break early if there is no request pending and no more incoming bytes polled into the read buffer + // This implementation is suboptimal: The loop is *not* broken if a incomplete packets resides in the + // rx buffer of `Network`. Until that frame is complete the outgoing queue is *not* flushed. + // Since the incomplete packet is already started to appear in the buffer it should be fine to await + // more data on the stream before flushing. + if self.pending.is_empty() + && self.requests.is_empty() + && network.read_buffer_remaining() == 0 + { + break; } - else => unreachable!("Eventloop select is exhaustive"), } + + network.flush().await?; + + self.state + .events + .pop_front() + .ok_or_else(|| unreachable!("empty event queue")) } } -/// Pending requets yielded with a configured rate. If the queue is empty the stream will yield pending. -struct PendingRequests { +/// Pending items yielded with a configured rate. If the queue is empty the stream will yield pending. +struct IntervalQueue { /// Interval interval: Option, /// Pending requests - requests: VecDeque, + queue: VecDeque, } -impl PendingRequests { +impl IntervalQueue { + /// Construct a new Pending instance pub fn new(interval: Duration) -> Self { let interval = (!interval.is_zero()).then(|| time::interval(interval)); - PendingRequests { + IntervalQueue { interval, - requests: VecDeque::new(), + queue: VecDeque::new(), } } + /// Returns true this queue is not empty pub fn is_empty(&self) -> bool { - self.requests.is_empty() + self.queue.is_empty() } - pub fn extend(&mut self, requests: impl IntoIterator) { - self.requests.extend(requests); + /// Extend the request queue + pub fn extend(&mut self, requests: impl IntoIterator) { + self.queue.extend(requests); + } + + /// Reset the pending interval tick + pub fn reset(&mut self) { + if let Some(interval) = self.interval.as_mut() { + interval.reset(); + } } } -impl Stream for PendingRequests { - type Item = Request; +impl Stream for IntervalQueue { + type Item = T; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.is_empty() { Poll::Pending } else { match self.interval.as_mut() { Some(interval) => match interval.poll_tick(cx) { - Poll::Ready(_) => Poll::Ready(self.requests.pop_front()), + Poll::Ready(_) => Poll::Ready(self.queue.pop_front()), Poll::Pending => Poll::Pending, }, - None => Poll::Ready(self.requests.pop_front()), + None => Poll::Ready(self.queue.pop_front()), } } } fn size_hint(&self) -> (usize, Option) { - (self.requests.len(), Some(self.requests.len())) + (self.queue.len(), Some(self.queue.len())) } } @@ -298,12 +332,6 @@ async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), Conne // make MQTT connection request (which internally awaits for ack) let packet = mqtt_connect(options, &mut network).await?; - - // Last session might contain packets which aren't acked. MQTT says these packets should be - // republished in the next session - // move pending messages from state to eventloop - // let pending = self.state.clean(); - // self.pending = pending.into_iter(); Ok((network, packet)) } @@ -434,17 +462,20 @@ async fn mqtt_connect( // validate connack match network.read().await? { - Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { - // Override local keep_alive value if set by server. + Some(Incoming::ConnAck(connack)) if connack.code == ConnectReturnCode::Success => { if let Some(props) = &connack.properties { + // Override local keep_alive value if set by server. if let Some(keep_alive) = props.server_keep_alive { options.keep_alive = Duration::from_secs(keep_alive as u64); } + + // Override max packet size network.set_max_outgoing_size(props.max_packet_size); } Ok(Packet::ConnAck(connack)) } - Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)), - packet => Err(ConnectionError::NotConnAck(Box::new(packet))), + Some(Incoming::ConnAck(connack)) => Err(ConnectionError::ConnectionRefused(connack.code)), + Some(packet) => Err(ConnectionError::NotConnAck(Box::new(packet))), + None => Err(ConnectionError::ConnectionClosed), } } diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index eebfed566..fe069e75e 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -1,3 +1,4 @@ +use bytes::Buf; use futures_util::SinkExt; use tokio_stream::StreamExt; use tokio_util::codec::Framed; @@ -5,7 +6,7 @@ use tokio_util::codec::Framed; use crate::framed::AsyncReadWrite; use super::mqttbytes::v5::Packet; -use super::{mqttbytes, Codec, Connect, Login, MqttOptions}; +use super::{Codec, Connect, Login, MqttOptions}; use super::{Incoming, StateError}; /// Network transforms packets <-> frames efficiently. It takes @@ -31,13 +32,16 @@ impl Network { self.framed.codec_mut().max_outgoing_size = max_outgoing_size; } + pub fn read_buffer_remaining(&self) -> usize { + self.framed.read_buffer().remaining() + } + /// Reads and returns a single packet from network - pub async fn read(&mut self) -> Result { + pub async fn read(&mut self) -> Result, StateError> { match self.framed.next().await { - Some(Ok(packet)) => Ok(packet), - Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(), + Some(Ok(packet)) => Ok(Some(packet)), Some(Err(e)) => Err(StateError::Deserialization(e)), - None => Err(StateError::ConnectionAborted), + None => Ok(None) } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 44499cde2..827f6e23f 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -78,8 +78,8 @@ pub struct MqttOptions { credentials: Option, /// request (publish, subscribe) channel capacity request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, + /// Max batch processing size + max_batch_size: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets pending_throttle: Duration, @@ -126,7 +126,7 @@ impl MqttOptions { client_id: id.into(), credentials: None, request_channel_capacity: 10, - max_request_batch: 0, + max_batch_size: 10, pending_throttle: Duration::from_micros(0), last_will: None, conn_timeout: 5, @@ -654,12 +654,12 @@ impl std::convert::TryFrom for MqttOptions { options.request_channel_capacity = request_channel_capacity; } - if let Some(max_request_batch) = queries - .remove("max_request_batch_num") + if let Some(max_batch_size) = queries + .remove("max_batch_size") .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) .transpose()? { - options.max_request_batch = max_request_batch; + options.max_batch_size = max_batch_size; } if let Some(pending_throttle) = queries @@ -704,7 +704,7 @@ impl Debug for MqttOptions { .field("client_id", &self.client_id) .field("credentials", &self.credentials) .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) + .field("max_request_batch", &self.max_batch_size) .field("pending_throttle", &self.pending_throttle) .field("last_will", &self.last_will) .field("conn_timeout", &self.conn_timeout) diff --git a/rumqttc/src/v5/mqttbytes/v5/codec.rs b/rumqttc/src/v5/mqttbytes/v5/codec.rs index 76909d62d..832ceaefe 100644 --- a/rumqttc/src/v5/mqttbytes/v5/codec.rs +++ b/rumqttc/src/v5/mqttbytes/v5/codec.rs @@ -3,8 +3,8 @@ use tokio_util::codec::{Decoder, Encoder}; use super::{Error, Packet}; -/// MQTT v4 codec -#[derive(Debug, Clone)] +/// MQTT v5 codec +#[derive(Default, Debug, Clone)] pub struct Codec { /// Maximum packet size allowed by client pub max_incoming_size: Option, @@ -33,16 +33,14 @@ impl Encoder for Codec { type Error = Error; fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { - item.write(dst, self.max_outgoing_size)?; - - Ok(()) + item.write(dst, self.max_outgoing_size).map(drop) } } #[cfg(test)] mod tests { - use bytes::BytesMut; - use tokio_util::codec::Encoder; + use bytes::{Buf, BytesMut}; + use tokio_util::codec::{Decoder, Encoder}; use super::Codec; use crate::v5::{ @@ -73,4 +71,55 @@ mod tests { _ => unreachable!(), } } + + #[test] + fn encode_decode_multiple_packets() { + let mut buf = BytesMut::new(); + let mut codec = Codec::default(); + let publish = Packet::Publish(Publish::new( + "hello/world", + QoS::AtMostOnce, + vec![1; 10], + None, + )); + + // Encode a fixed number of publications into `buf` + for _ in 0..100 { + codec + .encode(publish.clone(), &mut buf) + .expect("failed to encode"); + } + + // Decode a fixed number of packets from `buf` + for _ in 0..100 { + let result = codec.decode(&mut buf).expect("failed to encode"); + assert!(matches!(result, Some(p) if p == publish)); + } + + assert_eq!(buf.remaining(), 0); + } + + #[test] + fn decode_insufficient() { + let mut buf = BytesMut::new(); + let mut codec = Codec::default(); + let publish = Packet::Publish(Publish::new( + "hello/world", + QoS::AtMostOnce, + vec![1; 100], + None, + )); + + // Encode packet into `buf` + codec + .encode(publish.clone(), &mut buf) + .expect("failed to encode"); + let result = codec.decode(&mut buf); + assert!(matches!(result, Ok(Some(p)) if p == publish)); + + buf.resize(buf.remaining() / 2, 0); + + let result = codec.decode(&mut buf); + assert!(matches!(result, Ok(None))); + } } From e3893543b970405f3f7c0bc1fb56d9a03d0763e9 Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Tue, 26 Mar 2024 16:38:07 +0100 Subject: [PATCH 08/12] feat(rumqttc): simplify keepalive interval --- rumqttc/src/v5/eventloop.rs | 35 +++++++++++++++++------------------ rumqttc/src/v5/framed.rs | 4 ++-- rumqttc/src/v5/state.rs | 2 +- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index e109a6bf4..e96ace350 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -7,7 +7,8 @@ use crate::framed::AsyncReadWrite; use flume::Receiver; use futures_util::{Stream, StreamExt}; use tokio::select; -use tokio::time::{self, error::Elapsed, Instant, Sleep}; +use tokio::time::Interval; +use tokio::time::{self, error::Elapsed}; use std::collections::VecDeque; use std::io; @@ -85,7 +86,7 @@ pub struct EventLoop { /// Network connection to the broker network: Option, /// Keep alive time - keepalive_timeout: Option>>, + keepalive_interval: Interval, } /// Events which can be yielded by the event loop @@ -103,6 +104,8 @@ impl EventLoop { let pending = IntervalQueue::new(options.pending_throttle); let batch_size = options.max_batch_size; let state = MqttState::new(inflight_limit, manual_acks); + assert!(!options.keep_alive.is_zero()); + let keepalive_interval = time::interval(options.keep_alive()); EventLoop { options, @@ -111,7 +114,7 @@ impl EventLoop { requests, pending, network: None, - keepalive_timeout: None, + keepalive_interval, } } @@ -124,7 +127,6 @@ impl EventLoop { /// > For this reason we recommend setting [`AsycClient`](super::AsyncClient)'s channel capacity to `0`. pub fn clean(&mut self) { self.network = None; - self.keepalive_timeout = None; self.pending.extend(self.state.clean()); // drain requests from channel which weren't yet received @@ -144,14 +146,12 @@ impl EventLoop { .await??; self.network = Some(network); - if self.keepalive_timeout.is_none() { - self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); - } - // A connack never produces a response packet. Safe to ignore the return value // of `handle_incoming_packet` self.state.handle_incoming_packet(connack)?; - self.pending.reset(); + + self.pending.reset_immediately(); + self.keepalive_interval.reset(); } // Read buffered events from previous polls before calling a new poll @@ -216,8 +216,10 @@ impl EventLoop { network.write(packet).await?; } }, - // Process next packet received from io + // Process next packet from io packet = network.read() => { + // Reset keepalive interval due to packet reception + self.keepalive_interval.reset(); match packet? { Some(packet) => if let Some(packet) = self.state.handle_incoming_packet(packet)? { let flush = matches!(packet, Packet::PingResp(_)); @@ -229,11 +231,8 @@ impl EventLoop { None => return Err(ConnectionError::ConnectionClosed), } }, - // We generate pings irrespective of network activity. This keeps the ping logic - // simple. We can change this behavior in future if necessary (to prevent extra pings) - _ = self.keepalive_timeout.as_mut().unwrap() => { - let timeout = self.keepalive_timeout.as_mut().unwrap(); - timeout.as_mut().reset(Instant::now() + self.options.keep_alive); + // Send a ping request on each interval tick + _ = self.keepalive_interval.tick() => { if let Some(packet) = self.state.handle_outgoing_packet(Request::PingReq)? { network.write(packet).await?; } @@ -291,10 +290,10 @@ impl IntervalQueue { self.queue.extend(requests); } - /// Reset the pending interval tick - pub fn reset(&mut self) { + /// Reset the pending interval tick. Next tick yields immediately + pub fn reset_immediately(&mut self) { if let Some(interval) = self.interval.as_mut() { - interval.reset(); + interval.reset_immediately(); } } } diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index fe069e75e..29d357de5 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -6,7 +6,7 @@ use tokio_util::codec::Framed; use crate::framed::AsyncReadWrite; use super::mqttbytes::v5::Packet; -use super::{Codec, Connect, Login, MqttOptions}; +use super::{Codec, Connect, MqttOptions}; use super::{Incoming, StateError}; /// Network transforms packets <-> frames efficiently. It takes @@ -41,7 +41,7 @@ impl Network { match self.framed.next().await { Some(Ok(packet)) => Ok(Some(packet)), Some(Err(e)) => Err(StateError::Deserialization(e)), - None => Ok(None) + None => Ok(None), } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 95201bde6..7b1ae8ac8 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -65,7 +65,7 @@ pub enum StateError { #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, #[error("Connection closed by peer abruptly")] - ConnectionAborted + ConnectionAborted, } impl From for StateError { From f33c704c338b8bc43d010d90b0243ebf7afeb97f Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Wed, 27 Mar 2024 07:51:46 +0100 Subject: [PATCH 09/12] fix(rumqttc): create v5 eventloop within the tokio runtime Creating certain instances part of event loop needs to happen inside the target tokio runtime. --- rumqttc/src/v5/client.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 7d8a36bc8..6b56640f5 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -479,15 +479,15 @@ impl Client { /// /// `cap` specifies the capacity of the bounded async channel. pub fn new(options: MqttOptions, cap: usize) -> (Client, Connection) { - let (client, eventloop) = AsyncClient::new(options, cap); - let client = Client { client }; - let runtime = runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); + let (client, eventloop) = runtime.block_on(async { AsyncClient::new(options, cap) }); + let client = Client { client }; let connection = Connection::new(eventloop, runtime); + (client, connection) } From 9dc09d0700c52a92491f69164c6015bc49ea8ebd Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Wed, 27 Mar 2024 09:28:12 +0100 Subject: [PATCH 10/12] feat(rumqttc): add connection timeout Add timeouts on network read and write in the v4 client. Partial cleanup of the options: use references where possible instead of copying the network options. Use std::time::Duration instead of u32 secs (align with pending throttle). --- Cargo.lock | 36 +++++++++++ rumqttc/Cargo.toml | 1 + rumqttc/src/eventloop.rs | 18 +++--- rumqttc/src/lib.rs | 19 +++--- rumqttc/src/proxy.rs | 2 +- rumqttc/src/v5/eventloop.rs | 124 +++++++++++++++++++++++++++--------- rumqttc/src/v5/framed.rs | 20 ++---- rumqttc/src/v5/mod.rs | 33 ++++++---- 8 files changed, 174 insertions(+), 79 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b7396d431..3ea161459 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -150,6 +150,28 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.51", +] + [[package]] name = "async-trait" version = "0.1.77" @@ -1952,6 +1974,7 @@ dependencies = [ "tokio-native-tls", "tokio-rustls", "tokio-stream", + "tokio-test", "tokio-util", "url", "ws_stream_tungstenite", @@ -2557,6 +2580,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.10" diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index bba64822c..b3a3f2165 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -57,6 +57,7 @@ matches = "0.1" pretty_assertions = "1" pretty_env_logger = "0.5" serde = { version = "1", features = ["derive"] } +tokio-test = "0.4.4" [[example]] name = "tls" diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index a9b1ce8c5..621a2830a 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -141,8 +141,8 @@ impl EventLoop { pub async fn poll(&mut self) -> Result { if self.network.is_none() { let (network, connack) = match time::timeout( - Duration::from_secs(self.network_options.connection_timeout()), - connect(&self.mqtt_options, self.network_options.clone()), + self.network_options.connection_timeout(), + connect(&self.mqtt_options, self.network_options()), ) .await { @@ -173,7 +173,7 @@ impl EventLoop { // let await_acks = self.state.await_acks; let inflight_full = self.state.inflight >= self.mqtt_options.inflight; let collision = self.state.collision.is_some(); - let network_timeout = Duration::from_secs(self.network_options.connection_timeout()); + let network_timeout = self.network_options.connection_timeout(); // Read buffered events from previous polls before calling a new poll if let Some(event) = self.state.events.pop_front() { @@ -258,10 +258,12 @@ impl EventLoop { } } - pub fn network_options(&self) -> NetworkOptions { - self.network_options.clone() + /// Get network options + pub fn network_options(&self) -> &NetworkOptions { + &self.network_options } + /// Set network options pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self { self.network_options = network_options; self @@ -293,7 +295,7 @@ impl EventLoop { /// between re-connections so that cancel semantics can be used during this sleep async fn connect( mqtt_options: &MqttOptions, - network_options: NetworkOptions, + network_options: &NetworkOptions, ) -> Result<(Network, Incoming), ConnectionError> { // connect to the broker let mut network = network_connect(mqtt_options, network_options).await?; @@ -306,7 +308,7 @@ async fn connect( pub(crate) async fn socket_connect( host: String, - network_options: NetworkOptions, + network_options: &NetworkOptions, ) -> io::Result { let addrs = lookup_host(host).await?; let mut last_err = None; @@ -352,7 +354,7 @@ pub(crate) async fn socket_connect( async fn network_connect( options: &MqttOptions, - network_options: NetworkOptions, + network_options: &NetworkOptions, ) -> Result { // Process Unix files early, as proxy is not supported for them. #[cfg(unix)] diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 29cad1a34..566c58bdd 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -365,11 +365,12 @@ impl From for TlsConfiguration { } /// Provides a way to configure low level network connection configurations -#[derive(Clone, Default)] +#[derive(Clone, Debug, Default)] pub struct NetworkOptions { tcp_send_buffer_size: Option, tcp_recv_buffer_size: Option, - conn_timeout: u64, + /// Connection timeout + connection_timeout: Duration, #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] bind_device: Option, } @@ -379,7 +380,7 @@ impl NetworkOptions { NetworkOptions { tcp_send_buffer_size: None, tcp_recv_buffer_size: None, - conn_timeout: 5, + connection_timeout: Duration::from_secs(5), #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] bind_device: None, } @@ -393,15 +394,15 @@ impl NetworkOptions { self.tcp_recv_buffer_size = Some(size); } - /// set connection timeout in secs - pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { - self.conn_timeout = timeout; + /// Set connection timeout + pub fn set_connection_timeout(&mut self, timeout: Duration) -> &mut Self { + self.connection_timeout = timeout; self } - /// get timeout in secs - pub fn connection_timeout(&self) -> u64 { - self.conn_timeout + /// Get connection timeout + pub fn connection_timeout(&self) -> Duration { + self.connection_timeout } /// bind connection to a specific network device by name diff --git a/rumqttc/src/proxy.rs b/rumqttc/src/proxy.rs index 94c7aabd3..4d976df99 100644 --- a/rumqttc/src/proxy.rs +++ b/rumqttc/src/proxy.rs @@ -45,7 +45,7 @@ impl Proxy { self, broker_addr: &str, broker_port: u16, - network_options: NetworkOptions, + network_options: &NetworkOptions, ) -> Result, ProxyError> { let proxy_addr = format!("{}:{}", self.addr, self.port); diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index e96ace350..ac338b16f 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,22 +1,23 @@ -use super::framed::Network; -use super::mqttbytes::v5::*; -use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; -use crate::eventloop::socket_connect; -use crate::framed::AsyncReadWrite; +use super::{ + framed::Network, mqttbytes::v5::ConnectReturnCode, mqttbytes::v5::*, Incoming, MqttOptions, + MqttState, Outgoing, Request, StateError, Transport, +}; +use crate::{eventloop::socket_connect, framed::AsyncReadWrite}; use flume::Receiver; use futures_util::{Stream, StreamExt}; -use tokio::select; -use tokio::time::Interval; -use tokio::time::{self, error::Elapsed}; - -use std::collections::VecDeque; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; +use tokio::{ + select, + time::{self, error::Elapsed, timeout, Interval}, +}; -use super::mqttbytes::v5::ConnectReturnCode; +use std::{ + collections::VecDeque, + io, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] use crate::tls; @@ -77,8 +78,6 @@ pub struct EventLoop { pub options: MqttOptions, /// Current state of the connection pub state: MqttState, - /// Batch size - batch_size: usize, /// Request stream requests: Receiver, /// Pending requests from the last session @@ -102,7 +101,6 @@ impl EventLoop { let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX); let manual_acks = options.manual_acks; let pending = IntervalQueue::new(options.pending_throttle); - let batch_size = options.max_batch_size; let state = MqttState::new(inflight_limit, manual_acks); assert!(!options.keep_alive.is_zero()); let keepalive_interval = time::interval(options.keep_alive()); @@ -110,7 +108,6 @@ impl EventLoop { EventLoop { options, state, - batch_size, requests, pending, network: None, @@ -139,11 +136,8 @@ impl EventLoop { /// **NOTE** Don't block this while iterating pub async fn poll(&mut self) -> Result { if self.network.is_none() { - let (network, connack) = time::timeout( - Duration::from_secs(self.options.connection_timeout()), - connect(&mut self.options), - ) - .await??; + let connect_timeout = self.options.connect_timeout(); + let (network, connack) = timeout(connect_timeout, connect(&mut self.options)).await??; self.network = Some(network); // A connack never produces a response packet. Safe to ignore the return value @@ -171,8 +165,9 @@ impl EventLoop { /// Select on network and requests and generate keepalive pings when necessary async fn poll_process(&mut self) -> Result { let network = self.network.as_mut().unwrap(); + let network_timeout = self.options.network_options().connection_timeout(); - for _ in 0..self.batch_size { + for _ in 0..self.options.max_batch_size { let inflight_full = self.state.is_inflight_full(); let collision = self.state.has_collision(); @@ -207,13 +202,13 @@ impl EventLoop { // outgoing requests (along with 1b). Some(request) = self.pending.next(), if !inflight_full && !collision => { if let Some(packet) = self.state.handle_outgoing_packet(request)? { - network.write(packet).await?; + timeout(network_timeout, network.write(packet)).await??; } }, request = self.requests.recv_async(), if self.pending.is_empty() && !inflight_full && !collision => { let request = request.map_err(|_| ConnectionError::RequestsDone)?; if let Some(packet) = self.state.handle_outgoing_packet(request)? { - network.write(packet).await?; + timeout(network_timeout, network.write(packet)).await??; } }, // Process next packet from io @@ -223,7 +218,7 @@ impl EventLoop { match packet? { Some(packet) => if let Some(packet) = self.state.handle_incoming_packet(packet)? { let flush = matches!(packet, Packet::PingResp(_)); - network.write(packet).await?; + timeout(network_timeout, network.write(packet)).await??; if flush { break; } @@ -234,7 +229,7 @@ impl EventLoop { // Send a ping request on each interval tick _ = self.keepalive_interval.tick() => { if let Some(packet) = self.state.handle_outgoing_packet(Request::PingReq)? { - network.write(packet).await?; + timeout(network_timeout, network.write(packet)).await??; } } else => unreachable!("Eventloop select is exhaustive"), @@ -253,7 +248,7 @@ impl EventLoop { } } - network.flush().await?; + timeout(network_timeout, network.flush()).await??; self.state .events @@ -447,6 +442,7 @@ async fn mqtt_connect( let keep_alive = options.keep_alive().as_secs() as u16; let clean_start = options.clean_start(); let client_id = options.client_id(); + let connect_timeout = options.connect_timeout(); let properties = options.connect_properties(); let connect = Connect { @@ -457,7 +453,14 @@ async fn mqtt_connect( }; // send mqtt connect packet - network.connect(connect, options).await?; + let last_will = options.last_will(); + let login = options.credentials(); + let connect = Packet::Connect(connect, last_will, login); + timeout(connect_timeout, async { + network.write(connect).await?; + network.flush().await + }) + .await??; // validate connack match network.read().await? { @@ -478,3 +481,62 @@ async fn mqtt_connect( None => Err(ConnectionError::ConnectionClosed), } } + +#[tokio::test(start_paused = true)] +async fn connect_and_receive_connack() { + let mut options = MqttOptions::new("", "", 0); + + // Prepare a connect packet that is expected to be received. + let mut connect = bytes::BytesMut::new(); + Packet::Connect( + Connect { + keep_alive: options.keep_alive().as_secs() as u16, + client_id: options.client_id(), + clean_start: options.clean_start(), + properties: options.connect_properties(), + }, + options.last_will(), + options.credentials(), + ) + .write(&mut connect, None) + .ok(); + + // Prepare connect ack + let mut connect_ack = bytes::BytesMut::new(); + Packet::ConnAck(ConnAck { + session_present: false, + code: ConnectReturnCode::Success, + properties: None, + }) + .write(&mut connect_ack, None) + .ok(); + + // IO will assume a connect packet and *not* reply with a connack. + let io = tokio_test::io::Builder::new() + .write(&connect) + .read(&connect_ack) + .build(); + let mut network = Network::new(io, None); + + // Operation should timeout because io flush will not resolve. + let result = mqtt_connect(&mut options, &mut network).await; + + assert!(matches!(dbg!(result), Ok(Packet::ConnAck(ConnAck { .. })))); +} + +#[tokio::test(start_paused = true)] +async fn connect_timeouts_connect_packet_write() { + let mut options = MqttOptions::new("", "", 0); + options.set_connect_timeout(Duration::from_secs(10)); + + // IO will not accept the connect packet write + let io = tokio_test::io::Builder::new() + .wait(Duration::from_secs(30)) + .build(); + let mut network = Network::new(io, None); + + // Operation should timeout because io flush will not resolve. + let result = mqtt_connect(&mut options, &mut network).await; + + assert!(matches!(result, Err(ConnectionError::Timeout(_)))); +} diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index 29d357de5..ced4ebb62 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -5,9 +5,10 @@ use tokio_util::codec::Framed; use crate::framed::AsyncReadWrite; -use super::mqttbytes::v5::Packet; -use super::{Codec, Connect, MqttOptions}; -use super::{Incoming, StateError}; +use super::{ + mqttbytes::v5::Packet, + Codec, {Incoming, StateError}, +}; /// Network transforms packets <-> frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when @@ -60,17 +61,4 @@ impl Network { .await .map_err(StateError::Deserialization) } - - pub async fn connect( - &mut self, - connect: Connect, - options: &MqttOptions, - ) -> Result<(), StateError> { - let last_will = options.last_will(); - let login = options.credentials(); - self.write(Packet::Connect(connect, last_will, login)) - .await?; - - self.flush().await - } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 827f6e23f..dc38eb57d 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -85,8 +85,8 @@ pub struct MqttOptions { pending_throttle: Duration, /// Last will that will be issued on unexpected disconnect last_will: Option, - /// Connection timeout - conn_timeout: u64, + /// Connect timeout + connect_timeout: Duration, /// Default value of for maximum incoming packet size. /// Used when `max_incomming_size` in `connect_properties` is NOT available. default_max_incoming_size: u32, @@ -95,6 +95,7 @@ pub struct MqttOptions { /// If set to `true` MQTT acknowledgements are not sent automatically. /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. manual_acks: bool, + /// Network options network_options: NetworkOptions, #[cfg(feature = "proxy")] /// Proxy configuration. @@ -129,7 +130,7 @@ impl MqttOptions { max_batch_size: 10, pending_throttle: Duration::from_micros(0), last_will: None, - conn_timeout: 5, + connect_timeout: Duration::from_secs(5), default_max_incoming_size: 10 * 1024, connect_properties: None, manual_acks: false, @@ -290,15 +291,15 @@ impl MqttOptions { self.pending_throttle } - /// set connection timeout in secs - pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { - self.conn_timeout = timeout; + /// set connect timeout + pub fn set_connect_timeout(&mut self, timeout: Duration) -> &mut Self { + self.connect_timeout = timeout; self } - /// get timeout in secs - pub fn connection_timeout(&self) -> u64 { - self.conn_timeout + /// get connect timeout + pub fn connect_timeout(&self) -> Duration { + self.connect_timeout } /// set connection properties @@ -494,10 +495,12 @@ impl MqttOptions { self.manual_acks } - pub fn network_options(&self) -> NetworkOptions { - self.network_options.clone() + /// get network options + pub fn network_options(&self) -> &NetworkOptions { + &self.network_options } + /// set network options pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self { self.network_options = network_options; self @@ -676,11 +679,12 @@ impl std::convert::TryFrom for MqttOptions { .transpose()?; if let Some(conn_timeout) = queries - .remove("conn_timeout_secs") + .remove("connect_timeout_secs") .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) .transpose()? { - options.set_connection_timeout(conn_timeout); + let conn_timeout = Duration::from_secs(conn_timeout); + options.set_connect_timeout(conn_timeout); } if let Some((opt, _)) = queries.into_iter().next() { @@ -707,8 +711,9 @@ impl Debug for MqttOptions { .field("max_request_batch", &self.max_batch_size) .field("pending_throttle", &self.pending_throttle) .field("last_will", &self.last_will) - .field("conn_timeout", &self.conn_timeout) + .field("connect_timeout", &self.connect_timeout) .field("manual_acks", &self.manual_acks) + .field("network_options", &self.network_options) .field("connect properties", &self.connect_properties) .finish() } From 1271a878bb106b08ddeaa4a8fce61689964be2af Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Wed, 27 Mar 2024 09:38:03 +0100 Subject: [PATCH 11/12] fix(rumqttc): remove unsued error variants Remove the `Io` variant from `StateError` because this type of error is also present in `ConnectionError` and fits better there. Same for `ConnectionAborted` which is covered with `ConnectionError::ConnectionClosed`. --- rumqttc/src/v5/state.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 7b1ae8ac8..ef1e4939c 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -9,14 +9,11 @@ use super::{Event, Incoming, Outgoing, Request}; use bytes::Bytes; use std::collections::{HashMap, VecDeque}; -use std::{io, time::Instant}; +use std::time::Instant; /// Errors during state handling #[derive(Debug, thiserror::Error)] pub enum StateError { - /// Io Error while state is passed to network - #[error("Io error: {0:?}")] - Io(#[from] io::Error), #[error("Conversion error {0:?}")] Coversion(#[from] core::num::TryFromIntError), /// Invalid state for a given operation @@ -64,8 +61,6 @@ pub enum StateError { PubCompFail { reason: PubCompReason }, #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, - #[error("Connection closed by peer abruptly")] - ConnectionAborted, } impl From for StateError { From 5a18ebef9eda648b12c2caa2ca5a558f33bc2159 Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Wed, 27 Mar 2024 10:16:59 +0100 Subject: [PATCH 12/12] fix(rumqttc): fix from_url client option parsing --- rumqttc/src/lib.rs | 20 ++++++++++---------- rumqttc/src/v5/mod.rs | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 566c58bdd..7e28c4d3b 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -444,7 +444,7 @@ pub struct MqttOptions { /// request (publish, subscribe) channel capacity request_channel_capacity: usize, /// Max internal request batching - max_request_batch: usize, + max_batch_size: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets pending_throttle: Duration, @@ -484,7 +484,7 @@ impl MqttOptions { max_incoming_packet_size: 10 * 1024, max_outgoing_packet_size: 10 * 1024, request_channel_capacity: 10, - max_request_batch: 0, + max_batch_size: 0, pending_throttle: Duration::from_micros(0), inflight: 100, last_will: None, @@ -735,7 +735,7 @@ pub enum OptionError { RequestChannelCapacity, #[error("Invalid max-request-batch value.")] - MaxRequestBatch, + MaxBatchSize, #[error("Invalid pending-throttle value.")] PendingThrottle, @@ -843,12 +843,12 @@ impl std::convert::TryFrom for MqttOptions { options.request_channel_capacity = request_channel_capacity; } - if let Some(max_request_batch) = queries - .remove("max_request_batch_num") - .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + if let Some(max_batch_size) = queries + .remove("max_batch_size") + .map(|v| v.parse::().map_err(|_| OptionError::MaxBatchSize)) .transpose()? { - options.max_request_batch = max_request_batch; + options.max_batch_size = max_batch_size; } if let Some(pending_throttle) = queries @@ -888,7 +888,7 @@ impl Debug for MqttOptions { .field("credentials", &self.credentials) .field("max_packet_size", &self.max_incoming_packet_size) .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) + .field("max_batch_size", &self.max_batch_size) .field("pending_throttle", &self.pending_throttle) .field("inflight", &self.inflight) .field("last_will", &self.last_will) @@ -971,8 +971,8 @@ mod test { OptionError::RequestChannelCapacity ); assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), - OptionError::MaxRequestBatch + err("mqtt://host:42?client_id=foo&max_batch_size=foo"), + OptionError::MaxBatchSize ); assert_eq!( err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index dc38eb57d..f71f66330 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -790,7 +790,7 @@ mod test { OptionError::RequestChannelCapacity ); assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + err("mqtt://host:42?client_id=foo&max_batch_size=foo"), OptionError::MaxRequestBatch ); assert_eq!( @@ -802,7 +802,7 @@ mod test { OptionError::Inflight ); assert_eq!( - err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + err("mqtt://host:42?client_id=foo&connect_timeout_secs=foo"), OptionError::ConnTimeout ); }