Skip to content

Commit

Permalink
Support stream response
Browse files Browse the repository at this point in the history
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed Mar 8, 2025
1 parent 8ff8654 commit 8b891c9
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 71 deletions.
32 changes: 29 additions & 3 deletions python/ray/dashboard/subprocesses/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from ray.dashboard.subprocesses.utils import (
module_logging_filename,
ResponseType,
)

"""
Expand Down Expand Up @@ -215,14 +216,18 @@ async def _do_periodic_health_check(self):
await asyncio.sleep(1)

async def proxy_request(
self, request: aiohttp.web.Request, websocket=False
self, request: aiohttp.web.Request, resp_type: ResponseType = "http"
) -> aiohttp.web.StreamResponse:
"""
Sends a new request to the subprocess and returns the response.

The proxy method is selected by `resp_type`: "http" is handled by
`proxy_http`, "stream" by `proxy_stream` and "websocket" by
`proxy_websocket`. Any other value raises `ValueError`.
"""
if not websocket:
if resp_type == "http":
return await self.proxy_http(request)
return await self.proxy_websocket(request)
if resp_type == "stream":
return await self.proxy_stream(request)
if resp_type == "websocket":
return await self.proxy_websocket(request)
raise ValueError(f"Unknown response type: {resp_type}")

async def proxy_http(
self, request: aiohttp.web.Request
Expand All @@ -242,6 +247,27 @@ async def proxy_http(
status=resp.status, headers=filter_headers(resp.headers), body=resp_body
)

async def proxy_stream(
    self, request: aiohttp.web.Request
) -> aiohttp.web.StreamResponse:
    """
    Proxy handler for streaming responses.

    Forwards the method, path and query string, headers, and body of
    `request` to the backend, then relays the backend's status, headers and
    body to the client chunk by chunk as they arrive.

    Args:
        request: The incoming client request to forward.

    Returns:
        The already-sent `StreamResponse` mirroring the backend's reply.
    """
    url = f"http://localhost{request.path_qs}"
    body = await request.read()

    async with self.session.request(
        request.method, url, data=body, headers=request.headers
    ) as backend_resp:
        # Relay the backend's headers (e.g. Content-Type) instead of dropping
        # them, using the same filter proxy_http applies to its response
        # headers so both proxy paths stay consistent.
        client_resp = aiohttp.web.StreamResponse(
            status=backend_resp.status,
            headers=filter_headers(backend_resp.headers),
        )
        # Headers and status are committed here; from now on only body bytes
        # can be sent to the client.
        await client_resp.prepare(request)

        async for chunk in backend_resp.content.iter_chunked(1024):
            await client_resp.write(chunk)
        await client_resp.write_eof()
        return client_resp

async def proxy_websocket(
self, request: aiohttp.web.Request
) -> aiohttp.web.StreamResponse:
Expand Down
47 changes: 45 additions & 2 deletions python/ray/dashboard/subprocesses/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import collections
import inspect
import functools

from ray.dashboard.optional_deps import aiohttp

from ray.dashboard.routes import BaseRouteTable
from ray.dashboard.subprocesses.handle import SubprocessModuleHandle
from ray.dashboard.subprocesses.utils import ResponseType
from ray.dashboard.subprocesses.module import SubprocessModule
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,7 +60,7 @@ def predicate(o):
cls._bind_map[h.__route_method__][h.__route_path__].instance = instance

@classmethod
def _register_route(cls, method, path, websocket=False, **kwargs):
def _register_route(cls, method, path, resp_type: ResponseType = "http", **kwargs):
"""
Register a route to the module and return the decorated handler.
"""
Expand All @@ -77,12 +80,21 @@ def _wrapper(handler):

cls._bind_map[method][path] = bind_info

if resp_type == "http":
handler = handler
elif resp_type == "stream":
handler = decorate_stream_handler(handler)
elif resp_type == "websocket":
handler = decorate_websocket_handler(handler)
else:
raise ValueError(f"Unknown resp_type: {resp_type}")

async def parent_side_handler(
request: aiohttp.web.Request,
) -> aiohttp.web.Response:
bind_info = cls._bind_map[method][path]
subprocess_module_handle = bind_info.instance
return await subprocess_module_handle.proxy_request(request, websocket)
return await subprocess_module_handle.proxy_request(request, resp_type)

# Used in bind().
handler.__route_method__ = method
Expand All @@ -96,3 +108,34 @@ async def parent_side_handler(
return handler

return _wrapper


def decorate_stream_handler(handler):
    """
    Wrap an async-generator handler so it produces a streamed HTTP response.

    The wrapped handler is called as `handler(self, req, *args, **kwargs)` and
    every chunk it yields is written to an `aiohttp.web.StreamResponse`.

    Error handling depends on when the generator raises:
    - Before the first chunk: the exception propagates unchanged, because the
      response has not been started yet and the caller can still turn it
      into a regular error response.
    - After the first chunk: the status line has already been sent, so an
      `aiohttp.web.HTTPException` is reported by appending its `reason` to
      the body before the stream is closed.

    A generator that finishes without yielding anything results in a prepared
    response with an empty body.

    Args:
        handler: An async generator function yielding the chunks to send.

    Returns:
        A coroutine function with the same call signature that returns the
        finished `StreamResponse`.
    """
    assert inspect.isasyncgenfunction(
        handler
    ), "stream handlers must be async generator functions"

    @functools.wraps(handler)
    async def wrapper(
        self: SubprocessModule, req: aiohttp.web.Request, *args, **kwargs
    ):
        resp = aiohttp.web.StreamResponse()
        chunks = handler(self, req, *args, **kwargs)

        # Pull the first chunk before sending any headers, so an exception
        # raised up front can still become a proper error response.
        try:
            first_chunk = await chunks.__anext__()
        except StopAsyncIteration:
            # Empty stream: the response must still be prepared, otherwise
            # write_eof() is called on a response that was never started.
            await resp.prepare(req)
            await resp.write_eof()
            return resp

        await resp.prepare(req)
        await resp.write(first_chunk)
        try:
            async for chunk in chunks:
                await resp.write(chunk)
        except aiohttp.web.HTTPException as e:
            # Too late to change the status code; surface the reason in the body.
            await resp.write(e.reason.encode())
        finally:
            await resp.write_eof()
        return resp

    return wrapper


def decorate_websocket_handler(handler):
    """
    Placeholder decorator for websocket handlers.

    No websocket-specific wrapping is implemented yet, so the handler is
    returned unchanged. Returning it (rather than falling through and
    returning None) matters because the caller rebinds `handler` to this
    function's result and then sets route attributes on it.

    Args:
        handler: The websocket route handler.

    Returns:
        The same `handler` object, unmodified.
    """
    # TODO: add websocket-specific wrapping here once it is needed.
    return handler
62 changes: 31 additions & 31 deletions python/ray/dashboard/subprocesses/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,37 +92,37 @@ async def test_http_server(aiohttp_client, default_module_config):
assert "you shall not pass" in await response.text()


# async def test_streamed_iota(aiohttp_client, default_module_config):
# # TODO(ryw): also test streams that raise exceptions.
# app = await start_http_server_app(default_module_config, [TestModule])
# client = await aiohttp_client(app)
#
# response = await client.post("/streamed_iota", data=b"10")
# assert response.status == 200
# assert await response.text() == "0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n"
#
#
# async def test_streamed_iota_with_503(aiohttp_client, default_module_config):
# app = await start_http_server_app(default_module_config, [TestModule])
# client = await aiohttp_client(app)
#
# # Server behavior: sends 200 OK with 0-9, then an error message.
# response = await client.post("/streamed_iota_with_503", data=b"10")
# assert response.headers["Transfer-Encoding"] == "chunked"
# assert response.status == 200
# txt = await response.text()
# assert txt == "0\n1\n2\n3\n4\n5\n6\n7\n8\n9\nService Unavailable after 10 numbers"
#
#
# async def test_streamed_error_before_yielding(aiohttp_client, default_module_config):
# app = await start_http_server_app(default_module_config, [TestModule])
# client = await aiohttp_client(app)
#
# response = await client.post("/streamed_401", data=b"")
# assert response.status == 401
# assert await response.text() == "401: Unauthorized although I am not a teapot"
#
#
async def test_streamed_iota(aiohttp_client, default_module_config):
    """POSTing b"10" to /streamed_iota yields a 200 whose body is 0..9, one per line."""
    # TODO(ryw): also test streams that raise exceptions.
    expected_body = "0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n"

    server_app = await start_http_server_app(default_module_config, [TestModule])
    http_client = await aiohttp_client(server_app)

    resp = await http_client.post("/streamed_iota", data=b"10")
    assert resp.status == 200
    body_text = await resp.text()
    assert body_text == expected_body


async def test_streamed_iota_with_503(aiohttp_client, default_module_config):
    """
    An error raised mid-stream cannot change the status: the response stays a
    chunked 200 and the error reason is appended after the numbers.
    """
    expected_body = (
        "0\n1\n2\n3\n4\n5\n6\n7\n8\n9\nService Unavailable after 10 numbers"
    )

    server_app = await start_http_server_app(default_module_config, [TestModule])
    http_client = await aiohttp_client(server_app)

    # Server behavior: sends 200 OK with 0-9, then an error message.
    resp = await http_client.post("/streamed_iota_with_503", data=b"10")
    assert resp.headers["Transfer-Encoding"] == "chunked"
    assert resp.status == 200
    body_text = await resp.text()
    assert body_text == expected_body


async def test_streamed_error_before_yielding(aiohttp_client, default_module_config):
    """An error raised before the first chunk is reported with its own status (401)."""
    expected_body = "401: Unauthorized although I am not a teapot"

    server_app = await start_http_server_app(default_module_config, [TestModule])
    http_client = await aiohttp_client(server_app)

    resp = await http_client.post("/streamed_401", data=b"")
    assert resp.status == 401
    body_text = await resp.text()
    assert body_text == expected_body


async def test_kill_self(aiohttp_client, default_module_config):
"""
If a module died, all pending requests should be failed, and the module should be
Expand Down
79 changes: 45 additions & 34 deletions python/ray/dashboard/subprocesses/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import signal
from typing import AsyncIterator

from ray.dashboard.optional_deps import aiohttp

Expand Down Expand Up @@ -56,40 +57,50 @@ async def make_error_403(self, req: aiohttp.web.Request) -> aiohttp.web.Response
# https://github.com/ray-project/ray/pull/49732#discussion_r1919292428
raise aiohttp.web.HTTPForbidden(reason="you shall not pass")

# @routes.post("/streamed_iota", streaming=True)
# async def streamed_iota(self, request_body: bytes) -> AsyncIterator[bytes]:
# """
# Streams the numbers 0 to N.
# """
# n = int(request_body)
# for i in range(n):
# await asyncio.sleep(0.001)
# yield f"{i}\n".encode()
#
# @routes.post("/streamed_iota_with_503", streaming=True)
# async def streamed_iota_with_503(self, request_body: bytes) -> AsyncIterator[bytes]:
# """
# Streams the numbers 0 to N, then raises an error.
# """
# n = int(request_body)
# for i in range(n):
# await asyncio.sleep(0.001)
# yield f"{i}\n".encode()
# raise aiohttp.web.HTTPServiceUnavailable(
# reason=f"Service Unavailable after {n} numbers"
# )
#
# @routes.post("/streamed_401", streaming=True)
# async def streamed_401(self, request_body: bytes) -> AsyncIterator[bytes]:
# """
# Raises an error directly in a streamed handler before yielding any data.
# """
# raise aiohttp.web.HTTPUnauthorized(
# reason="Unauthorized although I am not a teapot"
# )
# # To make sure Python treats this method as an async generator, we yield something.
# yield b"Hello, World"
#
@routes.post("/streamed_iota", resp_type="stream")
async def streamed_iota(self, req: aiohttp.web.Request) -> AsyncIterator[bytes]:
    """
    Stream the integers 0 to N-1, one per line, where N is the request body.

    Raises HTTPBadRequest if the request body is not an integer.
    """
    raw_count = await req.text()
    try:
        n = int(raw_count)
    except ValueError:
        raise aiohttp.web.HTTPBadRequest(reason="Request body must be an integer")

    i = 0
    while i < n:
        # Short pause between chunks so the numbers are emitted one at a time.
        await asyncio.sleep(0.001)
        yield f"{i}\n".encode()
        i += 1

@routes.post("/streamed_iota_with_503", resp_type="stream")
async def streamed_iota_with_503(
    self, req: aiohttp.web.Request
) -> AsyncIterator[bytes]:
    """
    Stream the integers 0 to N-1, one per line, then fail with a 503 error.

    N is parsed from the request body; a non-integer body raises
    HTTPBadRequest. Once all numbers have been yielded,
    HTTPServiceUnavailable is raised.
    """
    raw_count = await req.text()
    try:
        n = int(raw_count)
    except ValueError:
        raise aiohttp.web.HTTPBadRequest(reason="Request body must be an integer")

    for i in range(n):
        await asyncio.sleep(0.001)
        line = f"{i}\n"
        yield line.encode()

    # Raised after the last chunk to exercise the error-after-data path.
    failure = aiohttp.web.HTTPServiceUnavailable(
        reason=f"Service Unavailable after {n} numbers"
    )
    raise failure

@routes.post("/streamed_401", resp_type="stream")
async def streamed_401(self, req: aiohttp.web.Request) -> AsyncIterator[bytes]:
    """
    Fail with HTTPUnauthorized immediately, before any chunk is yielded.
    """
    unauthorized = aiohttp.web.HTTPUnauthorized(
        reason="Unauthorized although I am not a teapot"
    )
    raise unauthorized
    # Never reached: the yield only exists so that Python compiles this
    # method as an async generator.
    yield b"Hello, World"

@routes.post("/run_forever")
async def run_forever(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
while True:
Expand Down
4 changes: 3 additions & 1 deletion python/ray/dashboard/subprocesses/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from typing import TypeVar
from typing import TypeVar, Literal

K = TypeVar("K")
V = TypeVar("V")

ResponseType = Literal["http", "stream", "websocket"]


def module_logging_filename(
module_name: str, incarnation: int, logging_filename: str
Expand Down

0 comments on commit 8b891c9

Please sign in to comment.