Skip to content

Commit 9f2c608

Browse files
author
Masahiro Tanaka
committed
run thread for event loop
1 parent 9cbb69f commit 9f2c608

File tree

3 files changed

+40
-23
lines changed

3 files changed

+40
-23
lines changed

mii/client.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
1010
from mii.constants import GRPC_MAX_MSG_SIZE
1111
from mii.method_table import GRPC_METHOD_TABLE
12+
from mii.event_loop import get_event_loop
1213

1314

1415
def _get_deployment_info(deployment_name):
@@ -56,7 +57,7 @@ class MIIClient():
5657
Client to send queries to a single endpoint.
5758
"""
5859
def __init__(self, task_name, host, port):
59-
self.asyncio_loop = asyncio.get_event_loop()
60+
self.asyncio_loop = get_event_loop()
6061
channel = create_channel(host, port)
6162
self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
6263
self.task = get_task(task_name)
@@ -73,17 +74,22 @@ async def _request_async_response(self, request_dict, **query_kwargs):
7374
proto_response
7475
) if "unpack_response_from_proto" in conversions else proto_response
7576

76-
def query(self, request_dict, **query_kwargs):
77-
return self.asyncio_loop.run_until_complete(
77+
def query_async(self, request_dict, **query_kwargs):
78+
return asyncio.run_coroutine_threadsafe(
7879
self._request_async_response(request_dict,
79-
**query_kwargs))
80+
**query_kwargs),
81+
get_event_loop())
82+
83+
def query(self, request_dict, **query_kwargs):
84+
return self.query_async(request_dict, **query_kwargs).result()
8085

8186
async def terminate_async(self):
8287
await self.stub.Terminate(
8388
modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())
8489

8590
def terminate(self):
86-
self.asyncio_loop.run_until_complete(self.terminate_async())
91+
asyncio.run_coroutine_threadsafe(self.terminate_async(),
92+
get_event_loop()).result()
8793

8894

8995
class MIITensorParallelClient():
@@ -94,7 +100,7 @@ class MIITensorParallelClient():
94100
def __init__(self, task_name, host, ports):
95101
self.task = get_task(task_name)
96102
self.clients = [MIIClient(task_name, host, port) for port in ports]
97-
self.asyncio_loop = asyncio.get_event_loop()
103+
self.asyncio_loop = get_event_loop()
98104

99105
# runs tasks in parallel and returns the result from the first task
100106
async def _query_in_tensor_parallel(self, request_string, query_kwargs):
@@ -106,7 +112,16 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs):
106112
**query_kwargs)))
107113

108114
await responses[0]
109-
return responses[0]
115+
return responses[0].result()
116+
117+
def query_async(self, request_dict, **query_kwargs):
118+
"""Asynchronously query a local deployment.
119+
See `query` for the arguments and the return value.
120+
"""
121+
return asyncio.run_coroutine_threadsafe(
122+
self._query_in_tensor_parallel(request_dict,
123+
query_kwargs),
124+
self.asyncio_loop)
110125

111126
def query(self, request_dict, **query_kwargs):
112127
"""Query a local deployment:
@@ -121,11 +136,7 @@ def query(self, request_dict, **query_kwargs):
121136
Returns:
122137
response: Response of the model
123138
"""
124-
response = self.asyncio_loop.run_until_complete(
125-
self._query_in_tensor_parallel(request_dict,
126-
query_kwargs))
127-
ret = response.result()
128-
return ret
139+
return self.query_async(request_dict, **query_kwargs).result()
129140

130141
def terminate(self):
131142
"""Terminates the deployment"""
@@ -135,5 +146,5 @@ def terminate(self):
135146

136147
def terminate_restful_gateway(deployment_name):
137148
_, mii_configs = _get_deployment_info(deployment_name)
138-
if mii_configs.restful_api_port > 0:
149+
if mii_configs.enable_restful_api:
139150
requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")

mii/event_loop.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import asyncio
2+
import threading
3+
4+
global event_loop
5+
event_loop = asyncio.get_event_loop()
6+
threading.Thread(target=event_loop.run_forever, daemon=True).start()
7+
8+
9+
def get_event_loop():
10+
return event_loop

mii/grpc_related/modelresponse_server.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mii.method_table import GRPC_METHOD_TABLE
1717
from mii.client import create_channel
1818
from mii.utils import get_task
19+
from mii.event_loop import get_event_loop
1920

2021

2122
class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer):
@@ -41,6 +42,7 @@ def __init__(self, inference_pipeline):
4142
super().__init__()
4243
self.inference_pipeline = inference_pipeline
4344
self.method_name_to_task = {m["method"]: t for t, m in GRPC_METHOD_TABLE.items()}
45+
self.lock = threading.Lock()
4446

4547
def _get_model_time(self, model, sum_times=False):
4648
model_times = []
@@ -71,7 +73,8 @@ def _run_inference(self, method_name, request_proto):
7173
args, kwargs = conversions["unpack_request_from_proto"](request_proto)
7274

7375
start = time.time()
74-
response = self.inference_pipeline(*args, **kwargs)
76+
with self.lock:
77+
response = self.inference_pipeline(*args, **kwargs)
7578
end = time.time()
7679

7780
model_time = self._get_model_time(self.inference_pipeline.model,
@@ -133,7 +136,7 @@ def __init__(self, host, ports):
133136
stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
134137
self.stubs.append(stub)
135138

136-
self.asyncio_loop = asyncio.get_event_loop()
139+
self.asyncio_loop = get_event_loop()
137140

138141
async def _invoke_async(self, method_name, proto_request):
139142
responses = []
@@ -153,7 +156,7 @@ def invoke(self, method_name, proto_request):
153156
class LoadBalancingInterceptor(grpc.ServerInterceptor):
154157
def __init__(self, task_name, replica_configs):
155158
super().__init__()
156-
self.asyncio_loop = asyncio.get_event_loop()
159+
self.asyncio_loop = get_event_loop()
157160

158161
self.stubs = [
159162
ParallelStubInvoker(replica.hostname,
@@ -163,13 +166,6 @@ def __init__(self, task_name, replica_configs):
163166
self.counter = AtomicCounter()
164167
self.task = get_task(task_name)
165168

166-
# Start the asyncio loop in a separate thread
167-
def run_asyncio_loop(loop):
168-
asyncio.set_event_loop(loop)
169-
loop.run_forever()
170-
171-
threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop, )).start()
172-
173169
def choose_stub(self, call_count):
174170
return self.stubs[call_count % len(self.stubs)]
175171

0 commit comments

Comments
 (0)