diff --git a/roslibrust/src/ros1/node/actor.rs b/roslibrust/src/ros1/node/actor.rs
index 296a816..4ec61e7 100644
--- a/roslibrust/src/ros1/node/actor.rs
+++ b/roslibrust/src/ros1/node/actor.rs
@@ -44,7 +44,7 @@ pub enum NodeMsg {
     // This results in the node's task ending and the node being dropped.
     Shutdown,
     RegisterPublisher {
-        reply: oneshot::Sender<Result<broadcast::Sender<Vec<u8>>, String>>,
+        reply: oneshot::Sender<Result<(broadcast::Sender<Vec<u8>>, mpsc::Sender<()>), String>>,
         topic: String,
         topic_type: String,
         queue_size: usize,
@@ -166,7 +166,7 @@ impl NodeServerHandle {
         topic: &str,
         queue_size: usize,
         latching: bool,
-    ) -> Result<broadcast::Sender<Vec<u8>>, NodeError> {
+    ) -> Result<(broadcast::Sender<Vec<u8>>, mpsc::Sender<()>), NodeError> {
         let (sender, receiver) = oneshot::channel();
         self.node_server_sender.send(NodeMsg::RegisterPublisher {
             reply: sender,
@@ -192,7 +192,7 @@ impl NodeServerHandle {
         msg_definition: &str,
         queue_size: usize,
         latching: bool,
-    ) -> Result<broadcast::Sender<Vec<u8>>, NodeError> {
+    ) -> Result<(broadcast::Sender<Vec<u8>>, mpsc::Sender<()>), NodeError> {
         let (sender, receiver) = oneshot::channel();
 
         let md5sum;
@@ -693,33 +693,43 @@ impl Node {
         msg_definition: String,
         md5sum: String,
         latching: bool,
-    ) -> Result<broadcast::Sender<Vec<u8>>, NodeError> {
+    ) -> Result<(broadcast::Sender<Vec<u8>>, mpsc::Sender<()>), NodeError> {
         // Return handle to existing Publication if it exists
         let existing_entry = {
             self.publishers.iter().find_map(|(key, value)| {
-                if key.as_str() == &topic {
-                    if value.topic_type() == topic_type {
-                        let sender = value.get_sender();
-                        return Some(Ok(sender));
-                    } else {
-                        warn!("Attempted to register publisher with different topic type than existing publisher: existing_type={}, new_type={}", value.topic_type(), topic_type);
+                if key.as_str() != &topic {
+                    return None;
+                }
+                if value.topic_type() != topic_type {
+                    warn!("Attempted to register publisher with different topic type than existing publisher: existing_type={}, new_type={}", value.topic_type(), topic_type);
+                    // TODO MAJOR: this is a terrible error type to return...
+                    return Some(Err(NodeError::IoError(std::io::Error::from(
+                        std::io::ErrorKind::AddrInUse,
+                    ))));
+                }
+                let (sender, shutdown) = value.get_senders();
+                match shutdown.upgrade() {
+                    Some(shutdown) => {
+                        Some(Ok((sender, shutdown)))
+                    }
+                    None => {
+                        error!("We still have an entry for a publication, but it has been shut down");
                         // TODO MAJOR: this is a terrible error type to return...
                         Some(Err(NodeError::IoError(std::io::Error::from(
                             std::io::ErrorKind::AddrInUse,
                         ))))
                     }
-                } else {
-                    None
                 }
             })
         };
 
         // If we found an existing publication return the handle to it
         if let Some(handle) = existing_entry {
-            return Ok(handle?);
+            let (sender, shutdown) = handle?;
+            return Ok((sender, shutdown));
         }
 
         // Otherwise create a new Publication and advertise
-        let (channel, sender) = Publication::new(
+        let (channel, sender, shutdown) = Publication::new(
             &self.node_name,
             latching,
             &topic,
@@ -737,7 +747,7 @@ impl Node {
         })?;
         self.publishers.insert(topic.clone(), channel);
         let _ = self.client.register_publisher(&topic, topic_type).await?;
-        Ok(sender)
+        Ok((sender, shutdown))
     }
 
     async fn unregister_publisher(&mut self, topic: &str) -> Result<(), NodeError> {
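Note on the existing-entry branch above: `get_senders()` hands back a `tokio::sync::mpsc::WeakSender<()>`, so the node must `upgrade()` it before returning the pair to the caller; an upgrade failure means every publisher handle has already been dropped and the publication is mid-teardown. A minimal standalone sketch of the tokio semantics this leans on (assumes a tokio new enough to have `WeakSender`, roughly 1.25+; not roslibrust code):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<()>(1);

    // A weak handle does not keep the channel alive on its own.
    let weak = tx.downgrade();
    assert!(weak.upgrade().is_some()); // a strong sender still exists

    drop(tx);
    // Once the last strong sender is gone, upgrade() fails...
    assert!(weak.upgrade().is_none());
    // ...and the receiver observes closure as `None`.
    assert_eq!(rx.recv().await, None);
}
```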
diff --git a/roslibrust/src/ros1/node/handle.rs b/roslibrust/src/ros1/node/handle.rs
index 3d2224b..f58d135 100644
--- a/roslibrust/src/ros1/node/handle.rs
+++ b/roslibrust/src/ros1/node/handle.rs
@@ -83,11 +83,11 @@ impl NodeHandle {
         queue_size: usize,
         latching: bool,
     ) -> Result<PublisherAny, NodeError> {
-        let sender = self
+        let (sender, shutdown) = self
             .inner
             .register_publisher_any(topic_name, topic_type, msg_definition, queue_size, latching)
             .await?;
-        Ok(PublisherAny::new(topic_name, sender))
+        Ok(PublisherAny::new(topic_name, sender, shutdown))
     }
 
     /// Create a new publisher for the given type.
@@ -103,11 +103,11 @@ impl NodeHandle {
         queue_size: usize,
         latching: bool,
     ) -> Result<Publisher<T>, NodeError> {
-        let sender = self
+        let (sender, shutdown) = self
             .inner
             .register_publisher::<T>(topic_name, queue_size, latching)
             .await?;
-        Ok(Publisher::new(topic_name, sender))
+        Ok(Publisher::new(topic_name, sender, shutdown))
     }
 
     pub async fn subscribe_any(
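Both advertise paths now thread an `mpsc::Sender<()>` into the publisher they return. Nothing is ever sent on that channel; the signal is purely drop-based, with the receiving side observing `None` once the last sender is gone. A hedged sketch of the pattern, using a hypothetical `Handle` type standing in for `Publisher`:

```rust
use tokio::sync::mpsc;

// Hypothetical stand-in for Publisher/PublisherAny: holding the sender is the
// entire mechanism, nothing is ever sent on it.
#[derive(Clone)]
struct Handle {
    _shutdown: mpsc::Sender<()>,
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<()>(1);
    let a = Handle { _shutdown: tx };
    let b = a.clone();

    let cleanup = tokio::spawn(async move {
        // Resolves to None only after every Handle is dropped.
        assert_eq!(rx.recv().await, None);
        println!("last handle dropped; unadvertise would happen here");
    });

    drop(a);
    drop(b);
    cleanup.await.unwrap();
}
```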
diff --git a/roslibrust/src/ros1/publisher.rs b/roslibrust/src/ros1/publisher.rs
index 6de98bf..c2b53e5 100644
--- a/roslibrust/src/ros1/publisher.rs
+++ b/roslibrust/src/ros1/publisher.rs
@@ -1,6 +1,5 @@
 use crate::ros1::{
     names::Name,
-    node::actor::NodeServerHandle,
     tcpros::{self, ConnectionHeader},
 };
 use abort_on_drop::ChildTask;
@@ -9,28 +8,37 @@ use roslibrust_codegen::RosMessageType;
 use std::{
     marker::PhantomData,
     net::{Ipv4Addr, SocketAddr},
-    sync::Arc,
 };
 use tokio::{
     io::AsyncWriteExt,
-    sync::{
-        broadcast::{self, error::RecvError},
-        RwLock,
-    },
+    sync::broadcast::{self, error::RecvError},
 };
 
+use super::actor::NodeServerHandle;
+
 /// The regular Publisher representation returned by calling advertise on a [crate::ros1::NodeHandle].
 pub struct Publisher<T> {
+    // Name of the topic this publisher is publishing on
     topic_name: String,
+    // Actual channel on which messages are sent to be published
     sender: broadcast::Sender<Vec<u8>>,
+    // When the last publisher for a given topic is dropped, this channel is used to signal
+    // cleanup of the underlying publication
+    _shutdown_channel: tokio::sync::mpsc::Sender<()>,
+    // Phantom data to ensure that the type is known at compile time
     phantom: PhantomData<T>,
 }
 
 impl<T: RosMessageType> Publisher<T> {
-    pub(crate) fn new(topic_name: &str, sender: broadcast::Sender<Vec<u8>>) -> Self {
+    pub(crate) fn new(
+        topic_name: &str,
+        sender: broadcast::Sender<Vec<u8>>,
+        shutdown_channel: tokio::sync::mpsc::Sender<()>,
+    ) -> Self {
         Self {
             topic_name: topic_name.to_owned(),
             sender,
+            _shutdown_channel: shutdown_channel,
             phantom: PhantomData,
         }
     }
@@ -58,14 +66,23 @@ impl<T: RosMessageType> Publisher<T> {
 pub struct PublisherAny {
     topic_name: String,
     sender: broadcast::Sender<Vec<u8>>,
+    // When the last publisher for a given topic is dropped, this channel is used to signal cleanup.
+    // We don't need to send a message; simply dropping the last handle lets the node know to clean up.
+    // Note: this has to be used because tokio::sync::broadcast doesn't have a WeakSender
+    _shutdown: tokio::sync::mpsc::Sender<()>,
     phantom: PhantomData<Vec<u8>>,
 }
 
 impl PublisherAny {
-    pub(crate) fn new(topic_name: &str, sender: broadcast::Sender<Vec<u8>>) -> Self {
+    pub(crate) fn new(
+        topic_name: &str,
+        sender: broadcast::Sender<Vec<u8>>,
+        shutdown: tokio::sync::mpsc::Sender<()>,
+    ) -> Self {
         Self {
             topic_name: topic_name.to_owned(),
             sender,
+            _shutdown: shutdown,
             phantom: PhantomData,
         }
     }
@@ -94,8 +111,11 @@ pub(crate) struct Publication {
     topic_type: String,
     listener_port: u16,
     _tcp_accept_task: ChildTask<()>,
-    // TODO: Need to make sure this isn't keeping things alive that it shouldn't
     publish_sender: broadcast::Sender<Vec<u8>>,
+    // We store a weak handle to the shutdown channel.
+    // This allows us to create new Publishers with a shutdown sender, but doesn't keep the shutdown channel alive.
+    // Had to add this because broadcast doesn't have a weak sender equivalent.
+    weak_shutdown_channel: tokio::sync::mpsc::WeakSender<()>,
 }
 
 impl Publication {
@@ -112,7 +132,14 @@ impl Publication {
         md5sum: &str,
         topic_type: &str,
         node_handle: NodeServerHandle,
-    ) -> Result<(Self, broadcast::Sender<Vec<u8>>), std::io::Error> {
+    ) -> Result<
+        (
+            Self,
+            broadcast::Sender<Vec<u8>>,
+            tokio::sync::mpsc::Sender<()>,
+        ),
+        std::io::Error,
+    > {
         // Get a socket for receiving connections on
         let host_addr = SocketAddr::from((host_addr, 0));
         let tcp_listener = tokio::net::TcpListener::bind(host_addr).await?;
@@ -121,7 +148,7 @@ impl Publication {
         // Setup the channel we will receive messages to be published on
         let (sender, receiver) = broadcast::channel::<Vec<u8>>(queue_size);
 
-        // Setup the ROS connection header that we'll respond to all incomming connections with
+        // Setup the ROS connection header that we'll respond to all incoming connections with
         let responding_conn_header = ConnectionHeader {
             caller_id: node_name.to_string(),
             latching,
@@ -134,24 +161,20 @@ impl Publication {
         };
         trace!("Publisher connection header: {responding_conn_header:?}");
 
-        // Setup storage for internal list of TCP streams
-        let subscriber_streams = Arc::new(RwLock::new(Vec::new()));
-
-        // Setup storage for the last message published (used for latching)
-        let last_message = Arc::new(RwLock::new(None));
+        // Setup a channel to signal to the publication to clean itself up
+        let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::channel(1);
+        let weak_shutdown_channel = shutdown_tx.downgrade();
 
         // Create the task that will accept new TCP connections
-        let subscriber_streams_copy = subscriber_streams.clone();
-        let last_message_copy = last_message.clone();
         let topic_name_copy = topic_name.to_owned();
         let tcp_accept_handle = tokio::spawn(async move {
             Self::tcp_accept_task(
                 tcp_listener,
-                subscriber_streams_copy,
                 topic_name_copy,
                 responding_conn_header,
-                last_message_copy,
                 receiver,
+                shutdown_rx,
+                node_handle,
             )
             .await
         });
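For context on the `broadcast::channel::<Vec<u8>>(queue_size)` created in this hunk: each subscriber connection later gets its own receiver via `resubscribe()`, every receiver sees every message sent after it subscribed, and `queue_size` bounds how far a slow receiver may fall behind. A small illustration of the plain tokio behavior (nothing roslibrust-specific):

```rust
use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    // Fan-out: every receiver sees every message sent after it subscribed.
    let (tx, mut rx_a) = broadcast::channel::<Vec<u8>>(4);
    let mut rx_b = tx.subscribe();

    tx.send(b"hello".to_vec()).unwrap();

    assert_eq!(rx_a.recv().await.unwrap(), b"hello".to_vec());
    assert_eq!(rx_b.recv().await.unwrap(), b"hello".to_vec());
}
```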
@@ -163,13 +186,23 @@ impl Publication {
                 _tcp_accept_task: tcp_accept_handle.into(),
                 listener_port,
                 publish_sender: sender,
+                weak_shutdown_channel,
             },
             sender_copy,
+            shutdown_tx,
         ))
     }
 
-    pub(crate) fn get_sender(&self) -> broadcast::Sender<Vec<u8>> {
-        self.publish_sender.clone()
+    pub(crate) fn get_senders(
+        &self,
+    ) -> (
+        broadcast::Sender<Vec<u8>>,
+        tokio::sync::mpsc::WeakSender<()>,
+    ) {
+        (
+            self.publish_sender.clone(),
+            self.weak_shutdown_channel.clone(),
+        )
     }
 
     pub(crate) fn port(&self) -> u16 {
@@ -188,17 +221,35 @@ impl Publication {
         mut rx: broadcast::Receiver<Vec<u8>>, // Receives messages to publish from the main buffer of messages
         mut stream: tokio::net::TcpStream,
         topic: String,
+        last_message: Option<Vec<u8>>, // If we're latching, will contain a message to send right away
     ) {
         let peer = stream.peer_addr();
         debug!("Publish task has started for publication: {topic} connection to {peer:?}");
+
+        if let Some(last_message) = last_message {
+            let res = stream.write_all(&last_message).await;
+            match res {
+                Ok(_) => {}
+                Err(e) => {
+                    error!("Failed to send latch message to subscriber: {e:?}");
+                }
+            }
+        }
+
         loop {
             match rx.recv().await {
                 Ok(msg_to_publish) => {
                     trace!("Publish task got message to publish for topic: {topic}");
-                    // Proxy the message to the watch channel
-                    if let Err(err) = stream.write_all(&msg_to_publish[..]).await {
-                        // TODO: A single failure between nodes that cross host boundaries is probably normal, should make this more robust perhaps
-                        debug!("Failed to send data to subscriber: {err}, removing");
+                    let send_result = stream.write_all(&msg_to_publish[..]).await;
+                    match send_result {
+                        Ok(_) => {
+                            trace!("Publish task sent message to topic: {topic}");
+                        }
+                        Err(err) => {
+                            // Shut down this TCP connection if we can't write a whole message
+                            debug!("Failed to send data to subscriber: {err}, removing");
+                            break;
+                        }
                     }
                 }
                 Err(RecvError::Lagged(num)) => {
@@ -206,24 +257,7 @@ impl Publication {
                     continue;
                 }
                 Err(RecvError::Closed) => {
-                    debug!(
-                        "No more senders for the publisher channel, triggering publication cleanup"
-                    );
-                    // TODO SHIT SHOW HERE
-                    // broadcast stuffs breaks our cleanup plan
-
-                    // All senders dropped, so we want to cleanup the Publication off of the node
-                    // Tell the node server to dispose of this publication and unadvertise it
-                    // Note: we need to do this in a spawned task or a drop-loop race condition will occur
-                    // Dropping publication results in this task being dropped, which can end up canceling the future that is doing the dropping
-                    // if we simply .await here
-                    // TODO: This allows publisher to clean themselves up iff node remains running after publisher is dropped...
-                    // NodeHandle clean-up is not resulting in a good state clean-up currently..
-                    // let nh_copy = node_handle.clone();
-                    // let topic = topic.clone();
-                    // tokio::spawn(async move {
-                    //     let _ = nh_copy.unregister_publisher(&topic).await;
-                    // });
+                    debug!("No more senders for the publisher channel, ending task");
                     break;
                 }
             }
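The `RecvError::Lagged` arm retained above is tokio's overflow report for broadcast channels: once a receiver falls more than the channel capacity behind, the oldest messages are overwritten, and the next `recv()` returns how many were skipped before resuming at the oldest retained message. A standalone demonstration (capacity of 2 chosen for brevity; not roslibrust code):

```rust
use tokio::sync::broadcast::{self, error::RecvError};

#[tokio::main]
async fn main() {
    // Capacity bounds how much history a slow receiver can recover.
    let (tx, mut rx) = broadcast::channel::<u32>(2);

    for i in 0..5 {
        tx.send(i).unwrap();
    }

    // This receiver fell behind: the oldest three messages were overwritten.
    match rx.recv().await {
        Err(RecvError::Lagged(skipped)) => assert_eq!(skipped, 3),
        other => panic!("expected lag, got {other:?}"),
    }
    // After reporting the lag, recv() resumes at the oldest retained message.
    assert_eq!(rx.recv().await.unwrap(), 3);
}
```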
@@ -236,106 +270,128 @@ impl Publication {
     /// This task constantly accepts new TCP connections and adds them to the list of streams to send data to.
     async fn tcp_accept_task(
         tcp_listener: tokio::net::TcpListener, // The TCP listener to accept connections on
-        subscriber_streams: Arc<RwLock<Vec<tokio::net::TcpStream>>>, // Where accepted streams are stored
-        topic_name: String, // Only used for logging
-        responding_conn_header: ConnectionHeader, // Header we respond with
-        last_message: Arc<RwLock<Option<Vec<u8>>>>, // Last message published (used for latching)
-        rx: broadcast::Receiver<Vec<u8>>, // Receives messages to publish from the main buffer of messages
+        topic_name: String,                       // Only used for logging
+        responding_conn_header: ConnectionHeader, // Header we respond with
+        mut rx: broadcast::Receiver<Vec<u8>>, // Receives messages to publish from the main buffer of messages
+        mut shutdown_rx: tokio::sync::mpsc::Receiver<()>, // Channel to signal to the publication to clean itself up
+        nh: NodeServerHandle,
     ) {
         debug!("TCP accept task has started for publication: {topic_name}");
+        let mut last_message = None;
         loop {
-            if let Ok((mut stream, peer_addr)) = tcp_listener.accept().await {
-                info!("Received connection from subscriber at {peer_addr} for topic {topic_name}");
-                // Read the connection header:
-                let connection_header = match tcpros::receive_header(&mut stream).await {
-                    Ok(header) => header,
-                    Err(e) => {
-                        error!("Failed to read connection header: {e:?}");
-                        stream
-                            .shutdown()
-                            .await
-                            .expect("Unable to shutdown tcpstream");
-                        continue;
-                    }
-                };
+            let result = tokio::select! {
+                shutdown = shutdown_rx.recv() => {
+                    match shutdown {
+                        Some(_) => error!("Message should never be sent on this channel"),
+                        None => debug!("TCP accept task has received shutdown signal for publication: {topic_name}"),
+                    }
+                    // Notify our Node that we're shutting down
+                    nh.unregister_publisher(&topic_name).await.unwrap();
+                    // Exit our loop and shut down this task
+                    break;
+                }
+                result = tcp_listener.accept() => {
+                    // Process the new TCP connection
+                    result
+                },
+                // TODO this can be optimized
+                // We shouldn't even call recv() if we're not latching
+                msg = rx.recv() => {
+                    match msg {
+                        Ok(msg) => {
+                            // If we're latching, save the message
+                            if responding_conn_header.latching {
+                                last_message = Some(msg);
+                            }
+                        },
+                        Err(RecvError::Lagged(num)) => {
+                            debug!("TCP accept task for {topic_name} is lagging behind, {num} messages were skipped");
+                            continue;
+                        }
+                        Err(RecvError::Closed) => {
+                            debug!("No more senders for the publisher channel, ending task");
+                            break;
+                        }
+                    }
+                    continue;
+                }
+            };
+
+            let (mut stream, peer_addr) = match result {
+                Ok(result) => result,
+                Err(e) => {
+                    error!("Error accepting TCP connection for topic {topic_name}: {e:?}");
+                    continue;
+                }
+            };
+
+            info!("Received connection from subscriber at {peer_addr} for topic {topic_name}");
+            // Read the connection header:
+            let connection_header = match tcpros::receive_header(&mut stream).await {
+                Ok(header) => header,
+                Err(e) => {
+                    error!("Failed to read connection header: {e:?}");
+                    stream
+                        .shutdown()
+                        .await
+                        .expect("Unable to shutdown tcpstream");
+                    continue;
+                }
+            };
-                debug!(
-                    "Received subscribe request for {:?} with md5sum {:?}",
-                    connection_header.topic, connection_header.md5sum
-                );
-                // I can't find documentation for this anywhere, but when using
-                // `rostopic hz` with one of our publishers I discovered that the rospy code sent "*" as the md5sum
-                // To indicate a "generic subscription"...
-                // I also discovered that `rostopic echo` does not send a md5sum (even thou ros documentation says its required)
-                if let Some(connection_md5sum) = connection_header.md5sum {
-                    if connection_md5sum != "*" {
-                        if let Some(local_md5sum) = &responding_conn_header.md5sum {
-                            // TODO(lucasw) is it ok to match any with "*"?
-                            // if local_md5sum != "*" && connection_md5sum != *local_md5sum {
-                            if connection_md5sum != *local_md5sum {
-                                warn!(
+            debug!(
+                "Received subscribe request for {:?} with md5sum {:?}",
+                connection_header.topic, connection_header.md5sum
+            );
+            // I can't find documentation for this anywhere, but when using
+            // `rostopic hz` with one of our publishers I discovered that the rospy code sent "*" as the md5sum
+            // to indicate a "generic subscription"...
+            // I also discovered that `rostopic echo` does not send a md5sum (even though ros documentation says it's required)
+            if let Some(connection_md5sum) = connection_header.md5sum {
+                if connection_md5sum != "*" {
+                    if let Some(local_md5sum) = &responding_conn_header.md5sum {
+                        // TODO(lucasw) is it ok to match any with "*"?
+                        // if local_md5sum != "*" && connection_md5sum != *local_md5sum {
+                        if connection_md5sum != *local_md5sum {
+                            warn!(
Expected {:?}, received {:?}", topic_name, local_md5sum, connection_md5sum, ); - // Close the TCP connection - stream - .shutdown() - .await - .expect("Unable to shutdown tcpstream"); - continue; - } + // Close the TCP connection + stream + .shutdown() + .await + .expect("Unable to shutdown tcpstream"); + continue; } } } - // Write our own connection header in response - let response_header_bytes = responding_conn_header - .to_bytes(false) - .expect("Couldn't serialize connection header"); - stream - .write_all(&response_header_bytes[..]) - .await - .expect("Unable to respond on tcpstream"); - - // If we're configured to latch, send the last message to the new subscriber - if responding_conn_header.latching { - if let Some(last_message) = last_message.read().await.as_ref() { - debug!( - "Publication configured to be latching and has last_message, sending" - ); - // TODO likely a bug here... but pretty subtle - // If we disconnect here, out accept task could get blocked trying to write this message out - // until the TCP Socket errors. Resulting in this publisher being unable to connect to new - // subscribers for some number of seconds. - // This write_all should be moved into a separate task - let res = stream.write_all(last_message).await; - match res { - Ok(_) => {} - Err(e) => { - error!("Failed to send latch message to subscriber: {e:?}"); - // Note doing any handling here, TCP stream will be cleaned up during - // next regular publish in the publish task - } - } - } - } - - // let mut wlock = subscriber_streams.write().await; - // wlock.push(stream); + } + // Write our own connection header in response + let response_header_bytes = responding_conn_header + .to_bytes(false) + .expect("Couldn't serialize connection header"); + stream + .write_all(&response_header_bytes[..]) + .await + .expect("Unable to respond on tcpstream"); - // Create a new task to handle writing to the TCP stream - let rx_copy = rx.resubscribe(); - let topic_name_copy = topic_name.clone(); - tokio::spawn(async move { - Self::publish_task(rx_copy, stream, topic_name_copy).await; - }); + // Create a new task to handle writing to the TCP stream + // Note: we continue to hold on to a root "rx" in this accept task that means that we + // always keep the channel open from the receive side. + let rx_copy = rx.resubscribe(); + let topic_name_copy = topic_name.clone(); + let last_message_copy = last_message.clone(); + tokio::spawn(async move { + Self::publish_task(rx_copy, stream, topic_name_copy, last_message_copy).await; + }); - debug!( - "Added stream for topic {:?} to subscriber {}", - connection_header.topic, peer_addr - ); - } + debug!( + "Added stream for topic {:?} to subscriber {}", + connection_header.topic, peer_addr + ); } } }