 from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
 from mii.constants import GRPC_MAX_MSG_SIZE
 from mii.method_table import GRPC_METHOD_TABLE
+from mii.event_loop import get_event_loop
 
 
 def _get_deployment_info(deployment_name):
@@ -56,7 +57,7 @@ class MIIClient():
     Client to send queries to a single endpoint.
     """
     def __init__(self, task_name, host, port):
-        self.asyncio_loop = asyncio.get_event_loop()
+        self.asyncio_loop = get_event_loop()
         channel = create_channel(host, port)
         self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
         self.task = get_task(task_name)
@@ -73,17 +74,22 @@ async def _request_async_response(self, request_dict, **query_kwargs):
             proto_response
         ) if "unpack_response_from_proto" in conversions else proto_response
 
-    def query(self, request_dict, **query_kwargs):
-        return self.asyncio_loop.run_until_complete(
+    def query_async(self, request_dict, **query_kwargs):
+        return asyncio.run_coroutine_threadsafe(
             self._request_async_response(request_dict,
-                                         **query_kwargs))
+                                         **query_kwargs),
+            get_event_loop())
+
+    def query(self, request_dict, **query_kwargs):
+        return self.query_async(request_dict, **query_kwargs).result()
 
     async def terminate_async(self):
         await self.stub.Terminate(
             modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())
 
     def terminate(self):
-        self.asyncio_loop.run_until_complete(self.terminate_async())
+        asyncio.run_coroutine_threadsafe(self.terminate_async(),
+                                         get_event_loop()).result()
 
 
 class MIITensorParallelClient():
@@ -94,7 +100,7 @@ class MIITensorParallelClient():
     def __init__(self, task_name, host, ports):
         self.task = get_task(task_name)
        self.clients = [MIIClient(task_name, host, port) for port in ports]
-        self.asyncio_loop = asyncio.get_event_loop()
+        self.asyncio_loop = get_event_loop()
 
     # runs task in parallel and return the result from the first task
     async def _query_in_tensor_parallel(self, request_string, query_kwargs):
@@ -106,7 +112,16 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs):
                                                    **query_kwargs)))
 
         await responses[0]
-        return responses[0]
+        return responses[0].result()
+
+    def query_async(self, request_dict, **query_kwargs):
+        """Asynchronously query a local deployment.
+        See `query` for the arguments and the return value.
+        """
+        return asyncio.run_coroutine_threadsafe(
+            self._query_in_tensor_parallel(request_dict,
+                                           query_kwargs),
+            self.asyncio_loop)
 
     def query(self, request_dict, **query_kwargs):
         """Query a local deployment:
@@ -121,11 +136,7 @@ def query(self, request_dict, **query_kwargs):
         Returns:
             response: Response of the model
         """
-        response = self.asyncio_loop.run_until_complete(
-            self._query_in_tensor_parallel(request_dict,
-                                           query_kwargs))
-        ret = response.result()
-        return ret
+        return self.query_async(request_dict, **query_kwargs).result()
 
     def terminate(self):
         """Terminates the deployment"""
@@ -135,5 +146,5 @@ def terminate(self):
 
 def terminate_restful_gateway(deployment_name):
     _, mii_configs = _get_deployment_info(deployment_name)
-    if mii_configs.restful_api_port > 0:
+    if mii_configs.enable_restful_api:
         requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
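
With this change, `query_async()` submits the request coroutine to a shared event loop via `asyncio.run_coroutine_threadsafe` and returns the resulting `concurrent.futures.Future`, while `query()` simply blocks on `.result()`. A minimal usage sketch of the new API, assuming a deployment already exists and a client handle is obtained through `mii.mii_query_handle`; the deployment name and payload shape below are illustrative assumptions, not part of this diff:

```python
# Sketch only: the deployment name and the {"query": [...]} payload are
# assumptions for a text-generation deployment, not taken from this diff.
import mii

client = mii.mii_query_handle("text_generation_deployment")

# query_async() returns a concurrent.futures.Future (the return value of
# asyncio.run_coroutine_threadsafe), so several requests can be in flight
# at once without blocking the caller.
futures = [
    client.query_async({"query": ["DeepSpeed is", "Seattle is"]})
    for _ in range(4)
]

# query() is now equivalent to query_async(...).result(): block only when
# the response is actually needed.
for f in futures:
    print(f.result())
```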