
Commit 0a33cde

alaeddine-13, jina-bot, and Joan Fontanals authored
fix: make gateway load balancer should stream results (#6024)
Co-authored-by: Jina Dev Bot <[email protected]>
Co-authored-by: Joan Fontanals <[email protected]>
1 parent 36e67f5 commit 0a33cde

File tree

3 files changed: +117 / -24 lines


jina/serve/runtimes/gateway/request_handling.py

Lines changed: 22 additions & 7 deletions
```diff
@@ -11,6 +11,7 @@
 
 if TYPE_CHECKING:  # pragma: no cover
     from types import SimpleNamespace
+
     import grpc
 
     from jina.logging.logger import JinaLogger
@@ -185,12 +186,23 @@ async def _load_balance(self, request):
         async with aiohttp.ClientSession() as session:
             if request.method == 'GET':
                 async with session.get(target_url) as response:
-                    content = await response.read()
-                    return web.Response(
-                        body=content,
+                    # Create a StreamResponse with the same headers and status as the target response
+                    stream_response = web.StreamResponse(
                         status=response.status,
-                        content_type=response.content_type,
+                        headers=response.headers,
                     )
+
+                    # Prepare the response to send headers
+                    await stream_response.prepare(request)
+
+                    # Stream the response from the target server to the client
+                    async for chunk in response.content.iter_any():
+                        await stream_response.write(chunk)
+
+                    # Close the stream response once all chunks are sent
+                    await stream_response.write_eof()
+                    return stream_response
+
             elif request.method == 'POST':
                 d = await request.read()
                 import json
@@ -282,7 +294,7 @@ async def stream(
             yield resp
 
     async def stream_doc(
-            self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext'
+        self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext'
     ) -> SingleDocumentRequest:
         """
         Process the received requests and return the result as a new request
@@ -293,7 +305,7 @@ async def stream_doc(
         """
         self.logger.debug('recv a stream_doc request')
         async for result in self.streamer.rpc_stream_doc(
-                request=request,
+            request=request,
         ):
             yield result
 
@@ -317,6 +329,7 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
         :returns: the response request
         """
         from google.protobuf import json_format
+
         self.logger.debug('got an endpoint discovery request')
         response = jina_pb2.EndpointsProto()
         await self.streamer._get_endpoints_input_output_models(is_cancel=None)
@@ -332,7 +345,9 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
             response.endpoints.extend(schema_maps.keys())
             json_format.ParseDict(schema_maps, response.schemas)
         else:
-            endpoints = await self.streamer.topology_graph._get_all_endpoints(self.streamer._connection_pool, retry_forever=True, is_cancel=None)
+            endpoints = await self.streamer.topology_graph._get_all_endpoints(
+                self.streamer._connection_pool, retry_forever=True, is_cancel=None
+            )
             response.endpoints.extend(list(endpoints))
         return response
```
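
The substantive change above is in `_load_balance`: instead of buffering the whole upstream body with `response.read()` and returning it as a single `web.Response`, the gateway now mirrors the upstream status and headers in an `aiohttp.web.StreamResponse` and forwards chunks as they arrive. Below is a minimal standalone sketch of that same pattern; the upstream address, route, and ports are placeholders for illustration, not values from this commit.

```python
import aiohttp
from aiohttp import web

UPSTREAM = 'http://localhost:8081'  # hypothetical upstream server, not from the commit


async def proxy(request: web.Request) -> web.StreamResponse:
    target_url = f'{UPSTREAM}{request.rel_url}'
    async with aiohttp.ClientSession() as session:
        async with session.get(target_url) as upstream:
            # Mirror the upstream status and headers instead of buffering the body
            stream_response = web.StreamResponse(
                status=upstream.status, headers=upstream.headers
            )
            # Send the response headers to the client right away
            await stream_response.prepare(request)

            # Relay each chunk as soon as the upstream produces it
            async for chunk in upstream.content.iter_any():
                await stream_response.write(chunk)

            await stream_response.write_eof()
            return stream_response


app = web.Application()
app.router.add_get('/{tail:.*}', proxy)

if __name__ == '__main__':
    web.run_app(app, port=8080)  # hypothetical port
```

With this pattern the first bytes reach the client before the upstream has finished producing the body, which is what lets server-sent-event style endpoints keep streaming through the gateway's load balancer.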

tests/integration/docarray_v2/test_streaming.py

Lines changed: 54 additions & 12 deletions
```diff
@@ -1,9 +1,11 @@
+import asyncio
+import time
 from typing import AsyncGenerator, Generator, Optional
 
 import pytest
+from docarray import BaseDoc, DocList
 
-from jina import Client, Executor, requests, Flow, Deployment
-from docarray import DocList, BaseDoc
+from jina import Client, Deployment, Executor, Flow, requests
 from jina.helper import random_port
 
 
@@ -67,20 +69,21 @@ async def test_streaming_deployment(protocol, include_gateway):
         assert doc.text == f'hello world {i}'
         i += 1
 
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize('protocol', ['http', 'grpc'])
 async def test_streaming_flow(protocol):
     port = random_port()
 
     with Flow(protocol=protocol, port=port, cors=True).add(
-            uses=MyExecutor,
+        uses=MyExecutor,
     ):
         client = Client(port=port, protocol=protocol, asyncio=True)
         i = 10
         async for doc in client.stream_doc(
-                on='/hello',
-                inputs=MyDocument(text='hello world', number=i),
-                return_type=MyDocument,
+            on='/hello',
+            inputs=MyDocument(text='hello world', number=i),
+            return_type=MyDocument,
         ):
             assert doc.text == f'hello world {i}'
             i += 1
@@ -111,27 +114,66 @@ async def test_streaming_custom_response(protocol, endpoint, include_gateway):
             i += 1
 
 
+class WaitStreamExecutor(Executor):
+    @requests(on='/hello')
+    async def task(self, doc: MyDocument, **kwargs) -> MyDocument:
+        for i in range(5):
+            yield MyDocument(text=f'{doc.text} {doc.number + i}')
+            await asyncio.sleep(0.5)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('protocol', ['http', 'grpc'])
+@pytest.mark.parametrize('include_gateway', [False, True])
+async def test_streaming_delay(protocol, include_gateway):
+    from jina import Deployment
+
+    port = random_port()
+
+    with Deployment(
+        uses=WaitStreamExecutor,
+        timeout_ready=-1,
+        protocol=protocol,
+        port=port,
+        include_gateway=include_gateway,
+    ):
+        client = Client(port=port, protocol=protocol, asyncio=True)
+        i = 0
+        start_time = time.time()
+        async for doc in client.stream_doc(
+            on='/hello',
+            inputs=MyDocument(text='hello world', number=i),
+            return_type=MyDocument,
+        ):
+            assert doc.text == f'hello world {i}'
+            i += 1
+
+        # 0.5 seconds between each request + 0.5 seconds tolerance interval
+        assert time.time() - start_time < (0.5 * i) + 0.5
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize('protocol', ['http', 'grpc'])
 @pytest.mark.parametrize('endpoint', ['task1', 'task2', 'task3'])
 async def test_streaming_custom_response_flow_one_executor(protocol, endpoint):
     port = random_port()
 
     with Flow(
-            protocol=protocol,
-            cors=True,
-            port=port,
+        protocol=protocol,
+        cors=True,
+        port=port,
     ).add(uses=CustomResponseExecutor):
         client = Client(port=port, protocol=protocol, cors=True, asyncio=True)
         i = 0
         async for doc in client.stream_doc(
-                on=f'/{endpoint}',
-                inputs=MyDocument(text='hello world', number=5),
-                return_type=OutputDocument,
+            on=f'/{endpoint}',
+            inputs=MyDocument(text='hello world', number=5),
+            return_type=OutputDocument,
         ):
             assert doc.text == f'hello world 5-{i}-{endpoint}'
             i += 1
 
+
 class Executor1(Executor):
     @requests
     def generator(self, doc: MyDocument, **kwargs) -> MyDocument:
```
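
The new `test_streaming_delay` asserts that documents arrive roughly every 0.5 s while the executor is still producing them, rather than all at once after it finishes. Outside pytest, the same behaviour can be observed with a small script along these lines; the port is an arbitrary placeholder and the `MyDocument` schema here is only a simplified stand-in for the one defined in the test module.

```python
import asyncio
import time

from docarray import BaseDoc
from jina import Client, Deployment, Executor, requests


class MyDocument(BaseDoc):
    # Simplified stand-in for the schema used in the tests
    text: str = ''
    number: int = 0


class WaitStreamExecutor(Executor):
    @requests(on='/hello')
    async def task(self, doc: MyDocument, **kwargs) -> MyDocument:
        # Yield five documents, half a second apart, as in the new test
        for i in range(5):
            yield MyDocument(text=f'{doc.text} {doc.number + i}')
            await asyncio.sleep(0.5)


async def main():
    port = 12345  # placeholder port
    with Deployment(uses=WaitStreamExecutor, protocol='http', port=port):
        client = Client(port=port, protocol='http', asyncio=True)
        start = time.time()
        async for doc in client.stream_doc(
            on='/hello',
            inputs=MyDocument(text='hello world', number=0),
            return_type=MyDocument,
        ):
            # With streaming intact, each line prints roughly 0.5 s after the previous one
            print(f'{time.time() - start:.2f}s  {doc.text}')


if __name__ == '__main__':
    asyncio.run(main())
```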

tests/integration/streaming/test_streaming.py

Lines changed: 41 additions & 5 deletions
```diff
@@ -1,3 +1,6 @@
+import asyncio
+import time
+
 import pytest
 
 from jina import Client, Deployment, Executor, requests
@@ -29,11 +32,10 @@ async def test_streaming_deployment(protocol, include_gateway):
         uses=MyExecutor,
         timeout_ready=-1,
         protocol=protocol,
-        cors=True,
         port=port,
         include_gateway=include_gateway,
     ):
-        client = Client(port=port, protocol=protocol, cors=True, asyncio=True)
+        client = Client(port=port, protocol=protocol, asyncio=True)
         i = 0
         async for doc in client.stream_doc(
             on='/hello', inputs=Document(text='hello world')
@@ -42,6 +44,42 @@ async def test_streaming_deployment(protocol, include_gateway):
             i += 1
 
 
+class WaitStreamExecutor(Executor):
+    @requests(on='/hello')
+    async def task(self, doc: Document, **kwargs):
+        for i in range(5):
+            yield Document(text=f'{doc.text} {i}')
+            await asyncio.sleep(0.5)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('protocol', ['http', 'grpc'])
+@pytest.mark.parametrize('include_gateway', [False, True])
+async def test_streaming_delay(protocol, include_gateway):
+    from jina import Deployment
+
+    port = random_port()
+
+    with Deployment(
+        uses=WaitStreamExecutor,
+        timeout_ready=-1,
+        protocol=protocol,
+        port=port,
+        include_gateway=include_gateway,
+    ):
+        client = Client(port=port, protocol=protocol, asyncio=True)
+        i = 0
+        start_time = time.time()
+        async for doc in client.stream_doc(
+            on='/hello', inputs=Document(text='hello world')
+        ):
+            assert doc.text == f'hello world {i}'
+            i += 1
+
+        # 0.5 seconds between each request + 0.5 seconds tolerance interval
+        assert time.time() - start_time < (0.5 * i) + 0.5
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize('protocol', ['grpc'])
 async def test_streaming_client_non_gen_endpoint(protocol):
@@ -53,11 +91,10 @@ async def test_streaming_client_non_gen_endpoint(protocol):
         uses=MyExecutor,
         timeout_ready=-1,
         protocol=protocol,
-        cors=True,
         port=port,
         include_gateway=False,
     ):
-        client = Client(port=port, protocol=protocol, cors=True, asyncio=True)
+        client = Client(port=port, protocol=protocol, asyncio=True)
         i = 0
         with pytest.raises(BadServer):
             async for _ in client.stream_doc(
@@ -67,7 +104,6 @@ async def test_streaming_client_non_gen_endpoint(protocol):
 
 
 def test_invalid_executor():
-
     with pytest.raises(RuntimeError) as exc_info:
 
         class InvalidExecutor3(Executor):
```
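
Because the load balancer now forwards chunks instead of buffering, the effect is also visible to a plain HTTP client: bytes arrive incrementally rather than in one final burst. Below is a minimal sketch using aiohttp's `iter_any()`, the same primitive the fix uses on the server side; the URL is a placeholder and should point at a streaming HTTP endpoint exposed behind a Jina gateway.

```python
import asyncio
import time

import aiohttp

URL = 'http://localhost:12345/hello'  # placeholder gateway endpoint, adjust to your setup


async def main():
    async with aiohttp.ClientSession() as session:
        async with session.get(URL) as resp:
            start = time.time()
            # iter_any() yields data as soon as it is received; after the fix,
            # chunks should appear spread out in time instead of all at the end
            async for chunk in resp.content.iter_any():
                print(f'{time.time() - start:.2f}s  {len(chunk)} bytes')


if __name__ == '__main__':
    asyncio.run(main())
```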
