From 5e28e316c8d88fa099389e1930f76fd22327b0b8 Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:02:04 -0400 Subject: [PATCH] fix aio exception handling (#2084) --- nvflare/fuel/f3/drivers/aio_context.py | 9 ++++ nvflare/fuel/f3/drivers/aio_grpc_driver.py | 59 +++++++++------------ nvflare/fuel/f3/drivers/grpc_driver.py | 60 +++++++++++----------- 3 files changed, 64 insertions(+), 64 deletions(-) diff --git a/nvflare/fuel/f3/drivers/aio_context.py b/nvflare/fuel/f3/drivers/aio_context.py index b159314b3b..6f91cf41d6 100644 --- a/nvflare/fuel/f3/drivers/aio_context.py +++ b/nvflare/fuel/f3/drivers/aio_context.py @@ -42,10 +42,19 @@ def get_event_loop(self): return self.loop + def _handle_exception(self, loop, context): + try: + msg = context.get("exception", context["message"]) + self.logger.debug(f"AIO Exception: {msg}") + except Exception as ex: + # ignore exception in the exception handler + self.logger.debug(f"exception in aio exception handler: {ex}") + def run_aio_loop(self): self.logger.debug(f"{self.name}: started AioContext in thread {threading.current_thread().name}") # self.loop = asyncio.get_event_loop() self.loop = asyncio.new_event_loop() + self.loop.set_exception_handler(self._handle_exception) asyncio.set_event_loop(self.loop) self.logger.debug(f"{self.name}: got loop: {id(self.loop)}") self.ready.set() diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 74c894be13..89c289af9b 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -69,8 +69,10 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di conf = CommConfigurator() if conf.get_bool_var("simulate_unstable_network", default=False): - self.disconn = threading.Thread(target=self._disconnect, daemon=True) - self.disconn.start() + if context: + # only server side + self.disconn = threading.Thread(target=self._disconnect, daemon=True) + self.disconn.start() def _disconnect(self): t = random.randint(10, 60) @@ -88,9 +90,11 @@ def close(self): if self.context: self.aio_ctx.run_coro(self.context.abort(grpc.StatusCode.CANCELLED, "service closed")) self.context = None + self.logger.info("Closed GRPC context") if self.channel: self.aio_ctx.run_coro(self.channel.close()) self.channel = None + self.logger.info("Closed GRPC Channel") def send_frame(self, frame: BytesAlike): try: @@ -298,57 +302,42 @@ async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, co address = get_address(params) self.logger.debug(f"CLIENT: trying to connect {address}") + connection = None try: secure = ssl_required(params) if secure: - grpc_channel = grpc.aio.secure_channel( + channel = grpc.aio.secure_channel( address, options=self.options, credentials=get_grpc_client_credentials(params) ) self.logger.info(f"created secure channel at {address}") else: - grpc_channel = grpc.aio.insecure_channel(address, options=self.options) + channel = grpc.aio.insecure_channel(address, options=self.options) self.logger.info(f"created insecure channel at {address}") + stub = StreamerStub(channel) - async with grpc_channel as channel: - self.logger.debug(f"CLIENT: connected to {address}") - stub = StreamerStub(channel) - conn_props = {DriverParams.PEER_ADDR.value: address} - - if secure: - conn_props[DriverParams.PEER_CN.value] = "N/A" + self.logger.debug(f"CLIENT: connected to {address}") + conn_props = {DriverParams.PEER_ADDR.value: address} - connection = AioStreamSession( - aio_ctx=aio_ctx, connector=connector, conn_props=conn_props, channel=channel - ) + if secure: + conn_props[DriverParams.PEER_CN.value] = "N/A" - try: - self.logger.debug(f"CLIENT: start streaming on connection {connection}") - msg_iter = stub.Stream(connection.generate_output()) - conn_ctx.conn = connection - await connection.read_loop(msg_iter) - except asyncio.CancelledError as error: - self.logger.debug(f"CLIENT: RPC cancelled: {error}") - except Exception as ex: - if self.closing: - self.logger.debug( - f"Connection {connection} closed by {type(ex)}: {secure_format_exception(ex)}" - ) - else: - self.logger.debug( - f"Connection {connection} client read exception {type(ex)}: {secure_format_exception(ex)}" - ) - self.logger.debug(secure_format_traceback()) + connection = AioStreamSession(aio_ctx=aio_ctx, connector=connector, conn_props=conn_props, channel=channel) - with connection.lock: - connection.channel = None - connection.close() + self.logger.debug(f"CLIENT: start streaming on connection {connection}") + msg_iter = stub.Stream(connection.generate_output()) + conn_ctx.conn = connection + await connection.read_loop(msg_iter) except asyncio.CancelledError: self.logger.debug("CLIENT: RPC cancelled") + except grpc.FutureCancelledError: + self.logger.info("CLIENT: Future cancelled") except Exception as ex: conn_ctx.error = f"connection {connection} error: {type(ex)}: {secure_format_exception(ex)}" self.logger.debug(conn_ctx.error) self.logger.debug(secure_format_traceback()) - + finally: + if connection: + connection.close() conn_ctx.waiter.set() def connect(self, connector: ConnectorInfo): diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py index a29399c253..f87fa89cda 100644 --- a/nvflare/fuel/f3/drivers/grpc_driver.py +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -68,13 +68,18 @@ def close(self): if self.context: try: self.context.abort(grpc.StatusCode.CANCELLED, "service closed") - except: + except Exception as ex: # ignore any exception when aborting - pass + self.logger.debug(f"exception aborting GRPC context: {secure_format_exception(ex)}") self.context = None + self.logger.info("Closed GRPC context") if self.channel: - self.channel.close() + try: + self.channel.close() + except Exception as ex: + self.logger.debug(f"exception closing GRPC channel: {secure_format_exception(ex)}") self.channel = None + self.logger.info("Closed GRPC Channel") def send_frame(self, frame: Union[bytes, bytearray, memoryview]): try: @@ -233,39 +238,36 @@ def connect(self, connector: ConnectorInfo): params = connector.params address = get_address(params) conn_props = {DriverParams.PEER_ADDR.value: address} + connection = None + try: + secure = ssl_required(params) + if secure: + self.logger.debug("CLIENT: creating secure channel") + channel = grpc.secure_channel( + address, options=self.options, credentials=get_grpc_client_credentials(params) + ) + self.logger.info(f"created secure channel at {address}") + else: + self.logger.info("CLIENT: creating insecure channel") + channel = grpc.insecure_channel(address, options=self.options) + self.logger.info(f"created insecure channel at {address}") - secure = ssl_required(params) - if secure: - self.logger.debug("CLIENT: creating secure channel") - channel = grpc.secure_channel( - address, options=self.options, credentials=get_grpc_client_credentials(params) - ) - self.logger.info(f"created secure channel at {address}") - else: - self.logger.info("CLIENT: creating insecure channel") - channel = grpc.insecure_channel(address, options=self.options) - self.logger.info(f"created insecure channel at {address}") - - with channel: stub = StreamerStub(channel) self.logger.debug("CLIENT: got stub") oq = QQ() connection = StreamConnection(oq, connector, conn_props, "CLIENT", channel=channel) self.add_connection(connection) self.logger.debug("CLIENT: added connection") - try: - received = stub.Stream(connection.generate_output()) - connection.read_loop(received) - - except BaseException as ex: - self.logger.info(f"CLIENT: connection done: {type(ex)}") - - with connection.lock: - # when we get here the channel is already closed - # set connection.channel to None to prevent closing channel again in connection.close(). - connection.channel = None - connection.close() - self.close_connection(connection) + received = stub.Stream(connection.generate_output()) + connection.read_loop(received) + except grpc.FutureCancelledError: + self.logger.debug("RPC Cancelled") + except Exception as ex: + self.logger.error(f"connection {connection} error: {type(ex)}: {secure_format_exception(ex)}") + finally: + if connection: + connection.close() + self.close_connection(connection) self.logger.info(f"CLIENT: finished connection {connection}") @staticmethod