Skip to content

Commit

Permalink
launch the grpc server in a separate thread
Browse files Browse the repository at this point in the history
Signed-off-by: Ata Fatahi <[email protected]>
  • Loading branch information
MrAta committed Dec 18, 2024
1 parent 9fbef84 commit f072559
Showing 1 changed file with 38 additions and 22 deletions.
60 changes: 38 additions & 22 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,17 @@ def __init__(
},
)

if server_args.grpc_port:
t = threading.Thread(
target=self._launch_grpc_server_in_loop,
name="gRPCServerThread",
daemon=True,
)
t.start()
logger.info(
f"TokenizerManager: launched gRPC server thread on port {server_args.grpc_port}"
)

async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
Expand Down Expand Up @@ -571,8 +582,13 @@ def create_handle_loop(self):
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())

signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
if threading.current_thread() is threading.main_thread():
signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
else:
# the thread that is used by grpc server doesn't need to handle the signal
logger.warning("Skipping add_signal_handler because not in main thread.")

loop.create_task(self.sigterm_watchdog())

async def sigterm_watchdog(self):
Expand All @@ -595,16 +611,6 @@ async def sigterm_watchdog(self):

async def handle_loop(self):
"""The event loop that handles requests"""
if self.server_args.grpc_port:
server = self._create_grpc_server(
host=self.server_args.host,
port=self.server_args.grpc_port,
)
await server.start()
logger.info(
f"gRPC server started on {self.server_args.host}:{self.server_args.grpc_port}"
)

while True:
recv_obj: Union[
BatchStrOut,
Expand Down Expand Up @@ -803,24 +809,34 @@ def detokenize_top_logprobs_tokens(
ret.append(None)
return ret

def _create_grpc_server(
self,
host: str = "0.0.0.0",
port: int = 50051,
max_workers: Optional[int] = None,
):
def _launch_grpc_server_in_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = loop.run_until_complete(self._create_grpc_server())
# Start the server before run_forever()
loop.run_until_complete(server.start())

logger.info(
f"gRPC server started, listening on {self.server_args.host}:{self.server_args.grpc_port}"
)
# Keep this loop alive so the server remains accessible
loop.run_forever()

async def _create_grpc_server(self):
# Create the server
server = grpc.aio.server(
futures.ThreadPoolExecutor(max_workers=max_workers),
options=[
("grpc.max_send_message_length", 100 * 1024 * 1024),
("grpc.max_receive_message_length", 100 * 1024 * 1024),
],
]
)

completion_pb2_grpc.add_CompletionServiceServicer_to_server(
CompletionServicer(self.generate_request), server
)
server.add_insecure_port(f"{host}:{port}")

server.add_insecure_port(
f"{self.server_args.host}:{self.server_args.grpc_port}"
)
return server


Expand Down

0 comments on commit f072559

Please sign in to comment.