diff --git a/roslibrust/Cargo.toml b/roslibrust/Cargo.toml index 5281302..51183c1 100644 --- a/roslibrust/Cargo.toml +++ b/roslibrust/Cargo.toml @@ -63,7 +63,7 @@ rosapi = ["serde-big-array"] # Intended for use with tests, includes tests that rely on a locally running rosbridge running_bridge = [] # For use with integration tests, indicating we are testing integration with a ros1 bridge -ros1_test = ["running_bridge"] +ros1_test = ["running_bridge", "ros1"] # For use with integration tests, indicates we are testing integration with a ros2 bridge ros2_test = ["running_bridge"] # Provides access to experimental abstract trait topic_provider diff --git a/roslibrust/src/ros1/node/actor.rs b/roslibrust/src/ros1/node/actor.rs index 558ea62..34f2a08 100644 --- a/roslibrust/src/ros1/node/actor.rs +++ b/roslibrust/src/ros1/node/actor.rs @@ -11,7 +11,7 @@ use crate::{ RosLibRustError, }; use abort_on_drop::ChildTask; -use log::warn; +use log::*; use roslibrust_codegen::{RosMessageType, RosServiceType}; use std::{collections::HashMap, io, net::Ipv4Addr, sync::Arc}; use tokio::sync::{broadcast, mpsc, oneshot}; @@ -82,10 +82,13 @@ pub enum NodeMsg { }, RequestTopic { reply: oneshot::Sender>, - caller_id: String, topic: String, protocols: Vec, }, + UnregisterPublisher { + reply: oneshot::Sender>, + topic: String, + }, } #[derive(Clone)] @@ -99,7 +102,7 @@ pub(crate) struct NodeServerHandle { impl NodeServerHandle { /// Get the URI of the master node. - pub async fn get_master_uri(&self) -> Result { + pub(crate) async fn get_master_uri(&self) -> Result { let (sender, receiver) = oneshot::channel(); self.node_server_sender .send(NodeMsg::GetMasterUri { reply: sender })?; @@ -107,7 +110,7 @@ impl NodeServerHandle { } /// Get the URI of the client node. - pub async fn get_client_uri(&self) -> Result { + pub(crate) async fn get_client_uri(&self) -> Result { let (sender, receiver) = oneshot::channel(); self.node_server_sender .send(NodeMsg::GetClientUri { reply: sender })?; @@ -116,7 +119,7 @@ impl NodeServerHandle { /// Gets the list of topics the node is currently subscribed to. /// Returns a tuple of (Topic Name, Topic Type) e.g. ("/rosout", "rosgraph_msgs/Log"). - pub async fn get_subscriptions(&self) -> Result, NodeError> { + pub(crate) async fn get_subscriptions(&self) -> Result, NodeError> { let (sender, receiver) = oneshot::channel(); self.node_server_sender .send(NodeMsg::GetSubscriptions { reply: sender })?; @@ -125,7 +128,7 @@ impl NodeServerHandle { /// Gets the list of topic the node is currently publishing to. /// Returns a tuple of (Topic Name, Topic Type) e.g. ("/rosout", "rosgraph_msgs/Log"). - pub async fn get_publications(&self) -> Result, NodeError> { + pub(crate) async fn get_publications(&self) -> Result, NodeError> { let (sender, receiver) = oneshot::channel(); self.node_server_sender .send(NodeMsg::GetPublications { reply: sender })?; @@ -134,7 +137,7 @@ impl NodeServerHandle { /// Updates the list of know publishers for a given topic /// This is used to know who to reach out to for updates - pub fn set_peer_publishers( + pub(crate) fn set_peer_publishers( &self, topic: String, publishers: Vec, @@ -148,14 +151,14 @@ impl NodeServerHandle { /// This will stop all ROS functionality and poison all NodeHandles connected /// to the underlying node server. // TODO this function should probably be pub(crate) and not pub? - pub fn shutdown(&self) -> Result<(), NodeError> { + pub(crate) fn shutdown(&self) -> Result<(), NodeError> { self.node_server_sender.send(NodeMsg::Shutdown)?; Ok(()) } /// Registers a publisher with the underlying node server /// Returns a channel that the raw bytes of a publish can be shoved into to queue the publish - pub async fn register_publisher( + pub(crate) async fn register_publisher( &self, topic: &str, queue_size: usize, @@ -177,10 +180,23 @@ impl NodeServerHandle { })?) } + pub(crate) async fn unregister_publisher(&self, topic: &str) -> Result<(), NodeError> { + let (sender, receiver) = oneshot::channel(); + self.node_server_sender.send(NodeMsg::UnregisterPublisher { + reply: sender, + topic: topic.to_owned(), + })?; + let rx = receiver.await?; + rx.map_err(|err| { + warn!("Failure while unregistering publisher: {err:?}"); + NodeError::IoError(io::Error::from(io::ErrorKind::ConnectionAborted)) + }) + } + /// Registers a service client with the underlying node server /// This returns a channel that can be used for making service calls /// service calls will be queued in the channel and resolved when able. - pub async fn register_service_client( + pub(crate) async fn register_service_client( &self, service_name: &Name, ) -> Result, NodeError> { @@ -209,7 +225,7 @@ impl NodeServerHandle { Ok(ServiceClient::new(service_name, sender, link)) } - pub async fn register_service_server( + pub(crate) async fn register_service_server( &self, service_name: &Name, server: F, @@ -256,7 +272,7 @@ impl NodeServerHandle { /// Called to remove a service server /// Delegates to the NodeServer via channel - pub async fn unadvertise_service(&self, service_name: &str) -> Result<(), NodeError> { + pub(crate) async fn unadvertise_service(&self, service_name: &str) -> Result<(), NodeError> { let (tx, rx) = oneshot::channel(); log::debug!("Queuing unregister service server command for: {service_name:?}"); self.node_server_sender @@ -274,7 +290,7 @@ impl NodeServerHandle { /// If this is the first time the given topic has been subscribed to (by this node) /// rosmaster will be informed. /// Otherwise, a new rx handle will simply be returned to the existing channel. - pub async fn register_subscriber( + pub(crate) async fn register_subscriber( &self, topic: &str, queue_size: usize, @@ -302,15 +318,13 @@ impl NodeServerHandle { // to marshal the response. // Users can call this function, but it really doesn't serve much of a purpose outside ROS Pub/Sub communication // negotiation - pub async fn request_topic( + pub(crate) async fn request_topic( &self, - caller_id: &str, topic: &str, protocols: &[String], ) -> Result { let (sender, receiver) = oneshot::channel(); self.node_server_sender.send(NodeMsg::RequestTopic { - caller_id: caller_id.to_owned(), topic: topic.to_owned(), protocols: protocols.into(), reply: sender, @@ -346,10 +360,12 @@ pub(crate) struct Node { // service_clients: HashMap, // Map of topic names to service server handles for each topic service_servers: HashMap, - // TODO need signal to shutdown xmlrpc server when node is dropped + // TODO MAJOR: need signal to shutdown xmlrpc server when node is dropped host_addr: Ipv4Addr, hostname: String, node_name: Name, + // Store a handle to ourself so that we can pass it out later + node_handle: NodeServerHandle, } impl Node { @@ -371,6 +387,10 @@ impl Node { let rosmaster_client = MasterClient::new(master_uri, client_uri, node_name.to_string()).await?; + let weak_handle = NodeServerHandle { + node_server_sender: node_sender.clone(), + _node_task: None, + }; let mut node = Self { client: rosmaster_client, _xmlrpc_server: xmlrpc_server, @@ -381,6 +401,7 @@ impl Node { host_addr: addr, hostname: hostname.to_owned(), node_name: node_name.to_owned(), + node_handle: weak_handle, }; let t = Arc::new( @@ -476,6 +497,13 @@ impl Node { } .expect("Failed to reply on oneshot"); } + NodeMsg::UnregisterPublisher { reply, topic } => { + let _ = reply.send( + self.unregister_publisher(&topic) + .await + .map_err(|err| err.to_string()), + ); + } NodeMsg::RegisterSubscriber { reply, topic, @@ -543,7 +571,6 @@ impl Node { reply, topic, protocols, - .. } => { // TODO: Should move the actual implementation similar to RegisterPublisher if protocols @@ -625,8 +652,16 @@ impl Node { self.publishers.iter().find_map(|(key, value)| { if key.as_str() == &topic { if value.topic_type() == topic_type { - Some(Ok(value.get_sender())) + if let Some(sender) = value.get_sender() { + return Some(Ok(sender)); + }else{ + // Edge case here + // The channel for the publication is closed, but publication hasn't been cleaned up yet + None + } } else { + warn!("Attempted to register publisher with different topic type than existing publisher: existing_type={}, new_type={}", value.topic_type(), topic_type); + // TODO MAJOR: this is a terrible error type to return... Some(Err(NodeError::IoError(std::io::Error::from( std::io::ErrorKind::AddrInUse, )))) @@ -636,30 +671,48 @@ impl Node { } }) }; + // If we found an existing publication return the handle to it if let Some(handle) = existing_entry { - Ok(handle?) - } else { - // Otherwise create a new Publication - let channel = Publication::new( - &self.node_name, - latching, - &topic, - self.host_addr, - queue_size, - &msg_definition, - &md5sum, - topic_type, - ) - .await - .map_err(|err| { - log::error!("Failed to create publishing channel: {err:?}"); - err - })?; - let handle = channel.get_sender(); - self.publishers.insert(topic.clone(), channel); - let _current_subscribers = self.client.register_publisher(&topic, topic_type).await?; - Ok(handle) + return Ok(handle?); + } + + // Otherwise create a new Publication + let (channel, sender) = Publication::new( + &self.node_name, + latching, + &topic, + self.host_addr, + queue_size, + &msg_definition, + &md5sum, + topic_type, + self.node_handle.clone(), + ) + .await + .map_err(|err| { + log::error!("Failed to create publishing channel: {err:?}"); + err + })?; + self.publishers.insert(topic.clone(), channel); + Ok(sender) + } + + async fn unregister_publisher(&mut self, topic: &str) -> Result<(), NodeError> { + // Tell ros master we are no longer publishing this topic + let err1 = self.client.unregister_publisher(topic).await; + // Remove the publication from our internal state + let err2 = self.publishers.remove(topic); + if err1.is_err() || err2.is_none() { + error!( + "Failure unregistering publisher: {err1:?}, {}", + err2.is_none() + ); + // MAJOR TODO: this is a terrible error type to return... + return Err(NodeError::IoError(std::io::Error::from( + std::io::ErrorKind::AddrInUse, + ))); } + Ok(()) } /// Checks the internal state of the NodeServer to see if it has a service client registered for this service already diff --git a/roslibrust/src/ros1/node/handle.rs b/roslibrust/src/ros1/node/handle.rs index f695778..05bd5ab 100644 --- a/roslibrust/src/ros1/node/handle.rs +++ b/roslibrust/src/ros1/node/handle.rs @@ -116,6 +116,7 @@ impl NodeHandle { Ok(ServiceServer::new(service_name, self.clone())) } + // TODO Major: This should probably be moved to NodeServerHandle? /// Not intended to be called manually /// Stops hosting the specified server. /// This is automatically called when dropping the ServiceServer returned by [advertise_service] diff --git a/roslibrust/src/ros1/node/mod.rs b/roslibrust/src/ros1/node/mod.rs index bcc40ca..7d0a569 100644 --- a/roslibrust/src/ros1/node/mod.rs +++ b/roslibrust/src/ros1/node/mod.rs @@ -7,7 +7,7 @@ use std::{ net::{IpAddr, Ipv4Addr}, }; -mod actor; +pub(crate) mod actor; mod handle; mod xmlrpc; use actor::*; diff --git a/roslibrust/src/ros1/node/xmlrpc.rs b/roslibrust/src/ros1/node/xmlrpc.rs index 99a11d8..2a58454 100644 --- a/roslibrust/src/ros1/node/xmlrpc.rs +++ b/roslibrust/src/ros1/node/xmlrpc.rs @@ -197,7 +197,7 @@ impl XmlRpcServer { let protocols = protocols.iter().flatten().cloned().collect::>(); debug!("Request for topic {topic} from {caller_id} via protocols {protocols:?}"); let params = node_server - .request_topic(&caller_id, &topic, &protocols) + .request_topic(&topic, &protocols) .await .map_err(|e| { Self::make_error_response( diff --git a/roslibrust/src/ros1/publisher.rs b/roslibrust/src/ros1/publisher.rs index bc04076..c455339 100644 --- a/roslibrust/src/ros1/publisher.rs +++ b/roslibrust/src/ros1/publisher.rs @@ -1,5 +1,6 @@ use crate::ros1::{ names::Name, + node::actor::NodeServerHandle, tcpros::{self, ConnectionHeader}, }; use abort_on_drop::ChildTask; @@ -52,11 +53,13 @@ pub(crate) struct Publication { listener_port: u16, _tcp_accept_task: ChildTask<()>, _publish_task: ChildTask<()>, - publish_sender: mpsc::Sender>, + publish_sender: mpsc::WeakSender>, } impl Publication { /// Spawns a new publication and sets up all tasks to run it + /// Returns a handle to the publication and a mpsc::Sender to send messages to be published + /// Dropping the Sender will (eventually) result in the publication being dropped and all tasks being canceled pub(crate) async fn new( node_name: &Name, latching: bool, @@ -66,7 +69,8 @@ impl Publication { msg_definition: &str, md5sum: &str, topic_type: &str, - ) -> Result { + node_handle: NodeServerHandle, + ) -> Result<(Self, mpsc::Sender>), std::io::Error> { // Get a socket for receiving connections on let host_addr = SocketAddr::from((host_addr, 0)); let tcp_listener = tokio::net::TcpListener::bind(host_addr).await?; @@ -96,13 +100,13 @@ impl Publication { // Create the task that will accept new TCP connections let subscriber_streams_copy = subscriber_streams.clone(); - let topic_name = topic_name.to_owned(); let last_message_copy = last_message.clone(); + let topic_name_copy = topic_name.to_owned(); let tcp_accept_handle = tokio::spawn(async move { Self::tcp_accept_task( tcp_listener, subscriber_streams_copy, - topic_name, + topic_name_copy, responding_conn_header, last_message_copy, ) @@ -110,21 +114,35 @@ impl Publication { }); // Create the task that will handle publishing messages to all streams + let topic_name_copy = topic_name.to_string(); let publish_task = tokio::spawn(async move { - Self::publish_task(receiver, subscriber_streams, last_message).await + Self::publish_task( + receiver, + subscriber_streams, + last_message, + node_handle, + topic_name_copy, + ) + .await }); - Ok(Self { - topic_type: topic_type.to_owned(), - _tcp_accept_task: tcp_accept_handle.into(), - listener_port, - publish_sender: sender, - _publish_task: publish_task.into(), - }) + let sender_copy = sender.clone(); + Ok(( + Self { + topic_type: topic_type.to_owned(), + _tcp_accept_task: tcp_accept_handle.into(), + listener_port, + publish_sender: sender.downgrade(), + _publish_task: publish_task.into(), + }, + sender_copy, + )) } - pub(crate) fn get_sender(&self) -> mpsc::Sender> { - self.publish_sender.clone() + // Note: this returns Option<> due to a timing edge case + // There can be a delay between when the last sender is dropped and when the publication is dropped + pub(crate) fn get_sender(&self) -> Option>> { + self.publish_sender.clone().upgrade() } pub(crate) fn port(&self) -> u16 { @@ -143,6 +161,8 @@ impl Publication { mut rx: mpsc::Receiver>, subscriber_streams: Arc>>, last_message: Arc>>>, + node_handle: NodeServerHandle, + topic: String, ) { loop { match rx.recv().await { @@ -169,7 +189,12 @@ impl Publication { *last_message.write().await = Some(msg_to_publish); } None => { - log::debug!("No more senders for the publisher channel, exiting..."); + log::debug!( + "No more senders for the publisher channel, triggering publication cleanup" + ); + // All senders dropped, so we want to cleanup the Publication off of the node + // Tell the node server to dispose of this publication and unadvertise it + let _ = node_handle.unregister_publisher(&topic).await; break; } } @@ -273,6 +298,12 @@ impl Publication { } } +impl Drop for Publication { + fn drop(&mut self) { + log::debug!("Dropping publication for topic {}", self.topic_type); + } +} + #[derive(thiserror::Error, Debug)] pub enum PublisherError { /// Serialize Error from `serde_rosmsg::Error` (stored as String because of dyn Error) diff --git a/roslibrust/tests/ros1_native_integration_tests.rs b/roslibrust/tests/ros1_native_integration_tests.rs index 4973783..df1b298 100644 --- a/roslibrust/tests/ros1_native_integration_tests.rs +++ b/roslibrust/tests/ros1_native_integration_tests.rs @@ -366,4 +366,42 @@ mod tests { debug!("Got call: {call:?}"); assert!(call.is_err()); } + + #[test_log::test(tokio::test)] + async fn test_dropping_publisher_unadvertises() { + let nh = NodeHandle::new("http://localhost:11311", "/test_dropping_publisher") + .await + .unwrap(); + let publisher = nh + .advertise::("/test_dropping_publisher", 1, false) + .await + .unwrap(); + + let master_client = roslibrust::ros1::MasterClient::new( + "http://localhost:11311", + "NAN", + "/test_dropping_publisher_mc", + ) + .await + .unwrap(); + + let before = master_client.get_published_topics("").await.unwrap(); + assert!(before.contains(&( + "/test_dropping_publisher".to_string(), + "std_msgs/Header".to_string() + ))); + + debug!("Start manual drop"); + // Drop the publisher + std::mem::drop(publisher); + debug!("End manual drop"); + // Give a little time for drop to process + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + // Confirm no longer advertised + let after = master_client.get_published_topics("").await.unwrap(); + assert!(!after.contains(&( + "/test_dropping_publisher".to_string(), + "std_msgs/Header".to_string() + ))); + } }