diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5d3b8c9b9a..20da17cde1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -176,6 +176,17 @@ def __init__( }, ) + if server_args.grpc_port: + t = threading.Thread( + target=self._launch_grpc_server_in_loop, + name="gRPCServerThread", + daemon=True, + ) + t.start() + logger.info( + f"TokenizerManager: launched gRPC server thread on port {server_args.grpc_port}" + ) + async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -571,8 +582,13 @@ def create_handle_loop(self): loop = asyncio.get_event_loop() loop.create_task(self.handle_loop()) - signal_handler = SignalHandler(self) - loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) + if threading.current_thread() is threading.main_thread(): + signal_handler = SignalHandler(self) + loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) + else: + # the thread that is used by grpc server doesn't need to handle the signal + logger.warning("Skipping add_signal_handler because not in main thread.") + loop.create_task(self.sigterm_watchdog()) async def sigterm_watchdog(self): @@ -595,16 +611,6 @@ async def sigterm_watchdog(self): async def handle_loop(self): """The event loop that handles requests""" - if self.server_args.grpc_port: - server = self._create_grpc_server( - host=self.server_args.host, - port=self.server_args.grpc_port, - ) - await server.start() - logger.info( - f"gRPC server started on {self.server_args.host}:{self.server_args.grpc_port}" - ) - while True: recv_obj: Union[ BatchStrOut, @@ -803,24 +809,34 @@ def detokenize_top_logprobs_tokens( ret.append(None) return ret - def _create_grpc_server( - self, - host: str = "0.0.0.0", - port: int = 50051, - max_workers: Optional[int] = None, - ): + def _launch_grpc_server_in_loop(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + server = loop.run_until_complete(self._create_grpc_server()) + # Start the server before run_forever() + loop.run_until_complete(server.start()) + + logger.info( + f"gRPC server started, listening on {self.server_args.host}:{self.server_args.grpc_port}" + ) + # Keep this loop alive so the server remains accessible + loop.run_forever() + + async def _create_grpc_server(self): + # Create the server server = grpc.aio.server( - futures.ThreadPoolExecutor(max_workers=max_workers), options=[ ("grpc.max_send_message_length", 100 * 1024 * 1024), ("grpc.max_receive_message_length", 100 * 1024 * 1024), - ], + ] ) - completion_pb2_grpc.add_CompletionServiceServicer_to_server( CompletionServicer(self.generate_request), server ) - server.add_insecure_port(f"{host}:{port}") + + server.add_insecure_port( + f"{self.server_args.host}:{self.server_args.grpc_port}" + ) return server