Skip to content

Commit

Permalink
Fixes to prevent unnecessary protocol message sends
Browse files Browse the repository at this point in the history
  • Loading branch information
sea-bass committed Oct 13, 2023
1 parent 8d76a9b commit 84f7461
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,6 @@ async def handle_request(self, req, res):
try:
return await future
finally:
self.protocol.send(
{
"op": "service_response",
"id": request_id,
"result": False,
}
)
del self.request_futures[request_id]

def handle_response(self, request_id, res):
Expand Down Expand Up @@ -86,7 +79,14 @@ def graceful_shutdown(self):
for future_id in self.request_futures:
future = self.request_futures[future_id]
future.set_exception(RuntimeError(f"Service {self.service_name} was unadvertised"))
self.protocol.node_handle.destroy_service(self.service_handle)
self.service_handle.destroy()
self.protocol.send(
{
"op": "service_response",
"id": future,
"result": False,
}
)


class AdvertiseService(Capability):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self, protocol):
protocol.node_handle.get_logger().info("Calling services in existing thread")
protocol.register_operation("call_service", self.call_service)

def call_service(self, message):
def call_service(self, message, sleep_time=0):
# Pull out the ID
cid = message.get("id", None)

Expand Down Expand Up @@ -112,7 +112,9 @@ def call_service(self, message):
e_cb = partial(self._failure, cid, service)

# Run service caller in the same thread.
ServiceCaller(trim_servicename(service), args, s_cb, e_cb, self.protocol.node_handle).run()
ServiceCaller(
trim_servicename(service), args, s_cb, e_cb, self.protocol.node_handle, sleep_time
).run()

def _success(self, cid, service, fragment_size, compression, message):
outgoing_message = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def service_response(self, message):
message_conversion.populate_instance(values, resp)
# pass along the response
service_handler.handle_response(request_id, resp)
self.protocol.send(message)
else:
self.protocol.log(
"error",
Expand Down
27 changes: 20 additions & 7 deletions rosbridge_library/src/rosbridge_library/internal/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, servicename):


class ServiceCaller(Thread):
def __init__(self, service, args, success_callback, error_callback, node_handle):
def __init__(self, service, args, success_callback, error_callback, node_handle, sleep_time=0):
"""Create a service caller for the specified service. Use start()
to start in a separate thread or run() to run in this thread.
Expand All @@ -64,7 +64,8 @@ def __init__(self, service, args, success_callback, error_callback, node_handle)
service call
error_callback -- a callback to call if an error occurs. The
callback will be passed the exception that caused the failure
node_handle -- a ROS2 node handle to call services.
node_handle -- a ROS 2 node handle to call services.
sleep_time -- if nonzero, puts sleeps between executor spins
"""
Thread.__init__(self)
self.daemon = True
Expand All @@ -73,11 +74,16 @@ def __init__(self, service, args, success_callback, error_callback, node_handle)
self.success = success_callback
self.error = error_callback
self.node_handle = node_handle
self.sleep_time = sleep_time

def run(self):
try:
# Call the service and pass the result to the success handler
self.success(call_service(self.node_handle, self.service, self.args))
self.success(
call_service(
self.node_handle, self.service, args=self.args, sleep_time=self.sleep_time
)
)
except Exception as e:
# On error, just pass the exception to the error handler
self.error(e)
Expand All @@ -99,7 +105,7 @@ def args_to_service_request_instance(service, inst, args):
populate_instance(msg, inst)


def call_service(node_handle, service, args=None):
def call_service(node_handle, service, args=None, sleep_time=0):
# Given the service name, fetch the type and class of the service,
# and a request instance
service = expand_topic_name(service, node_handle.get_name(), node_handle.get_namespace())
Expand All @@ -122,10 +128,17 @@ def call_service(node_handle, service, args=None):
client = node_handle.create_client(service_class, service)

future = client.call_async(inst)
if node_handle.executor:
node_handle.executor.spin_until_future_complete(future)
if sleep_time == 0:
if node_handle.executor:
node_handle.executor.spin_until_future_complete(future)
else:
rclpy.spin_until_future_complete(future)
else:
rclpy.spin_until_future_complete(node_handle, future)
while not future.done():
if node_handle.executor:
node_handle.executor.spin_once(timeout_sec=sleep_time)
else:
rclpy.spin_once(node_handle, timeout_sec=sleep_time)
result = future.result()

node_handle.destroy_client(client)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def test_call_advertised_service(self):
)
)
self.received_message = None
Thread(target=call_service.call_service, args=(call_msg,)).start()
sleep_time = 0.01
Thread(target=call_service.call_service, args=(call_msg, sleep_time)).start()

loop_iterations = 0
while self.received_message is None:
Expand Down Expand Up @@ -182,7 +183,8 @@ def test_unadvertise_with_live_request(self):
)
)
self.received_message = None
Thread(target=call_service.call_service, args=(call_msg,)).start()
sleep_time = 0.01
Thread(target=call_service.call_service, args=(call_msg, sleep_time)).start()

loop_iterations = 0
while self.received_message is None:
Expand Down
18 changes: 14 additions & 4 deletions rosbridge_library/test/internal/services/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,17 @@ def populate_random_args(d):


class ServiceTester:
def __init__(self, name, srv_type):
def __init__(self, executor, name, srv_type):
self.name = name
self.executor = executor
self.node = Node("service_tester_" + srv_type.replace("/", "_"))
self.executor.add_node(self.node)
self.srvClass = ros_loader.get_service_class(srv_type)
self.service = self.node.create_service(self.srvClass, name, self.callback)

def __del__(self):
self.executor.remove_node(self.node)

def start(self):
req = self.srvClass.Request()
gen = c.extract_values(req)
Expand Down Expand Up @@ -85,6 +90,7 @@ def setUp(self):
self.executor.add_node(self.node)

def tearDown(self):
self.executor.remove_node(self.node)
rclpy.shutdown()

def msgs_equal(self, msg1, msg2):
Expand Down Expand Up @@ -185,7 +191,9 @@ def error():
self.assertEqual(x, y)

def test_service_tester(self):
t = ServiceTester("/test_service_tester", "rosbridge_test_msgs/TestRequestAndResponse")
t = ServiceTester(
self.executor, "/test_service_tester", "rosbridge_test_msgs/TestRequestAndResponse"
)
t.start()
time.sleep(0.2)
t.validate(self.msgs_equal)
Expand All @@ -201,7 +209,9 @@ def test_service_tester_alltypes(self):
"TestMultipleRequestFields",
"TestArrayRequest",
]:
t = ServiceTester("/test_service_tester_alltypes_" + srv, "rosbridge_test_msgs/" + srv)
t = ServiceTester(
self.executor, "/test_service_tester_alltypes_" + srv, "rosbridge_test_msgs/" + srv
)
t.start()
ts.append(t)

Expand All @@ -223,7 +233,7 @@ def test_random_service_types(self):
]
ts = []
for srv in common:
t = ServiceTester("/test_random_service_types/" + srv, srv)
t = ServiceTester(self.executor, "/test_random_service_types/" + srv, srv)
t.start()
ts.append(t)

Expand Down
11 changes: 6 additions & 5 deletions rosbridge_server/test/websocket/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import rclpy.task
from autobahn.twisted.websocket import WebSocketClientFactory, WebSocketClientProtocol
from rcl_interfaces.srv import GetParameters
from rclpy.executors import SingleThreadedExecutor
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from twisted.internet import reactor
from twisted.internet.endpoints import TCP4ClientEndpoint
Expand Down Expand Up @@ -118,8 +118,8 @@ def run_websocket_test(
):
context = rclpy.Context()
rclpy.init(context=context)
executor = SingleThreadedExecutor(context=context)
node = rclpy.create_node(node_name, context=context)
executor = MultiThreadedExecutor(context=context)
node = Node(node_name, context=context)
executor.add_node(node)

async def task():
Expand All @@ -128,11 +128,12 @@ async def task():

future = executor.create_task(task)

reactor.callInThread(rclpy.spin_until_future_complete, node, future, executor)
reactor.callInThread(executor.spin_until_future_complete, future)
reactor.run(installSignalHandlers=False)

rclpy.shutdown(context=context)
executor.remove_node(node)
node.destroy_node()
rclpy.shutdown(context=context)


def sleep(node: Node, duration: float) -> Awaitable[None]:
Expand Down

0 comments on commit 84f7461

Please sign in to comment.