diff --git a/rock/admin/entrypoints/sandbox_proxy_api.py b/rock/admin/entrypoints/sandbox_proxy_api.py index 8d9d40b42..2770fa7d2 100644 --- a/rock/admin/entrypoints/sandbox_proxy_api.py +++ b/rock/admin/entrypoints/sandbox_proxy_api.py @@ -1,7 +1,8 @@ import asyncio import logging +from typing import Any -from fastapi import APIRouter, File, Form, Request, UploadFile, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Body, File, Form, Request, UploadFile, WebSocket, WebSocketDisconnect from rock.actions import ( BashObservation, @@ -36,6 +37,7 @@ def set_sandbox_proxy_service(service: SandboxProxyService): global sandbox_proxy_service sandbox_proxy_service = service + @sandbox_proxy_router.post("/execute") @handle_exceptions(error_message="execute command failed") async def execute(command: SandboxCommand) -> RockResponse[CommandResponse]: @@ -127,3 +129,15 @@ async def websocket_proxy(websocket: WebSocket, id: str, path: str = ""): async def get_token(): result = await asyncio.to_thread(sandbox_proxy_service.gen_oss_sts_token) return RockResponse(result=result) + + +@sandbox_proxy_router.post("/sandboxes/{sandbox_id}/proxy") +@sandbox_proxy_router.post("/sandboxes/{sandbox_id}/proxy/{path:path}") +@handle_exceptions(error_message="post proxy failed") +async def post_proxy( + sandbox_id: str, + request: Request, + path: str = "", + body: dict[str, Any] = Body(None), +): + return await sandbox_proxy_service.post_proxy(sandbox_id, path, body, request.headers) diff --git a/rock/sandbox/service/sandbox_proxy_service.py b/rock/sandbox/service/sandbox_proxy_service.py index 1017bdfe4..49f121d3b 100644 --- a/rock/sandbox/service/sandbox_proxy_service.py +++ b/rock/sandbox/service/sandbox_proxy_service.py @@ -1,13 +1,15 @@ import asyncio # noqa: I001 import json import time +from fastapi.responses import JSONResponse, StreamingResponse +from starlette.datastructures import Headers import httpx import oss2 import websockets from aliyunsdkcore import client from aliyunsdkcore.request import CommonRequest -from fastapi import UploadFile +from fastapi import Response, UploadFile from starlette.status import HTTP_504_GATEWAY_TIMEOUT from rock import env_vars @@ -411,3 +413,91 @@ def _matches_query_params(self, sandbox_info: SandboxInfo, query_params: Sandbox if filter_key not in sandbox_info or sandbox_info[filter_key] != filter_value: return False return True + + async def post_proxy( + self, + sandbox_id: str, + target_path: str, + body: dict | None, + headers: Headers, + ) -> JSONResponse | StreamingResponse | Response: + """HTTP POST proxy that supports both streaming (SSE) and non-streaming responses.""" + + EXCLUDED_HEADERS = {"host", "content-length", "transfer-encoding"} + + def filter_headers(raw_headers: Headers) -> dict: + return {k: v for k, v in raw_headers.items() if k.lower() not in EXCLUDED_HEADERS} + + status_list = await self.get_service_status(sandbox_id) + service_status = ServiceStatus.from_dict(status_list[0]) + + host_ip = status_list[0].get("host_ip") + port = service_status.get_mapped_port(Port.SERVER) + target_url = f"http://{host_ip}:{port}/{target_path}" + + request_headers = filter_headers(headers) + payload = body or {} + + client = httpx.AsyncClient(timeout=httpx.Timeout(None)) + + try: + resp = await client.send( + client.build_request( + method="POST", + url=target_url, + json=payload, + headers=request_headers, + timeout=120, + ), + stream=True, + ) + except Exception: + await client.aclose() + raise + + content_type = resp.headers.get("content-type", "") + is_sse = "text/event-stream" in content_type + response_headers = filter_headers(resp.headers) + + if is_sse: + + async def event_stream(): + """Forward upstream bytes to downstream as soon as they arrive.""" + try: + if resp.status_code >= 400: + yield await resp.aread() + return + + async for chunk in resp.aiter_bytes(): + if chunk: + yield chunk + finally: + await resp.aclose() + await client.aclose() + + return StreamingResponse( + event_stream(), + status_code=resp.status_code, + media_type="text/event-stream", + headers=response_headers, + ) + + try: + raw_content = await resp.aread() + + if "application/json" in content_type: + return JSONResponse( + status_code=resp.status_code, + content=resp.json(), + headers=response_headers, + ) + + return Response( + status_code=resp.status_code, + content=raw_content, + media_type=content_type or "application/octet-stream", + headers=response_headers, + ) + finally: + await resp.aclose() + await client.aclose() diff --git a/tests/unit/sandbox/test_sandbox_http_proxy.py b/tests/unit/sandbox/test_sandbox_http_proxy.py new file mode 100644 index 000000000..8095675f8 --- /dev/null +++ b/tests/unit/sandbox/test_sandbox_http_proxy.py @@ -0,0 +1,184 @@ +import asyncio +import json + +import pytest +from starlette.datastructures import Headers + +from rock.admin.proto.request import ( + SandboxBashAction as BashAction, +) +from rock.admin.proto.request import ( + SandboxCloseBashSessionRequest as CloseBashSessionRequest, +) +from rock.admin.proto.request import ( + SandboxCreateBashSessionRequest as CreateSessionRequest, +) +from rock.deployments.config import DockerDeploymentConfig +from rock.logger import init_logger +from rock.sandbox.sandbox_manager import SandboxManager +from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService +from tests.unit.conftest import check_sandbox_status_until_alive + +logger = init_logger(__name__) + +ECHO_SERVER_SCRIPT = r""" +import json +from http.server import HTTPServer, BaseHTTPRequestHandler + +class Handler(BaseHTTPRequestHandler): + def do_POST(self): + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) if content_length else b"{}" + data = json.loads(body) + + if self.path == "/stream": + self.send_response(200) + self.send_header("Content-Type", "text/event-stream") + self.send_header("Cache-Control", "no-cache") + self.end_headers() + for i in range(3): + chunk = {"index": i, "echo": data} + self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode()) + self.wfile.flush() + self.wfile.write(b"data: [DONE]\n\n") + self.wfile.flush() + else: + response = {"path": self.path, "echo": data} + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + + def log_message(self, format, *args): + pass + +HTTPServer(("0.0.0.0", 8080), Handler).serve_forever() +""" + +SESSION_NAME = "test" + + +async def start_echo_server_in_sandbox( + sandbox_proxy_service: SandboxProxyService, + sandbox_id: str, +) -> None: + """Start an echo HTTP server inside the sandbox container via create_session + run_in_session.""" + + create_req = CreateSessionRequest( + session=SESSION_NAME, + sandbox_id=sandbox_id, + ) + await sandbox_proxy_service.create_session(create_req) + + # Write echo server script to file + write_action = BashAction( + action_type="bash", + sandbox_id=sandbox_id, + session=SESSION_NAME, + command="cat > /tmp/echo_server.py << 'PYEOF'\n" + ECHO_SERVER_SCRIPT.strip() + "\nPYEOF", + ) + await sandbox_proxy_service.run_in_session(write_action) + + # Start echo server in background + start_action = BashAction( + action_type="bash", + sandbox_id=sandbox_id, + session=SESSION_NAME, + command="nohup python3 /tmp/echo_server.py > /tmp/server.log 2>&1 & echo $!", + ) + start_result = await sandbox_proxy_service.run_in_session(start_action) + logger.info(f"echo server started, result: {start_result}") + + # Wait for server to be ready + await asyncio.sleep(2) + + +@pytest.mark.need_ray +@pytest.mark.asyncio +async def test_post_proxy(sandbox_manager: SandboxManager, sandbox_proxy_service: SandboxProxyService): + response = await sandbox_manager.start_async(DockerDeploymentConfig(cpus=0.5, memory="1g")) + sandbox_id = response.sandbox_id + await check_sandbox_status_until_alive(sandbox_manager, sandbox_id) + + try: + await start_echo_server_in_sandbox(sandbox_proxy_service, sandbox_id) + + mock_headers = Headers({"content-type": "application/json"}) + + # Test with path and body + result = await sandbox_proxy_service.post_proxy( + sandbox_id=sandbox_id, + target_path="api/test", + body={"hello": "world"}, + headers=mock_headers, + ) + assert result.status_code == 200 + response_body = json.loads(result.body) + assert response_body["path"] == "/api/test" + assert response_body["echo"] == {"hello": "world"} + + # Test without path + result = await sandbox_proxy_service.post_proxy( + sandbox_id=sandbox_id, + target_path="", + body={"key": "value"}, + headers=mock_headers, + ) + assert result.status_code == 200 + response_body = json.loads(result.body) + assert response_body["echo"] == {"key": "value"} + + # Test with body as None + result = await sandbox_proxy_service.post_proxy( + sandbox_id=sandbox_id, + target_path="health", + body=None, + headers=mock_headers, + ) + assert result.status_code == 200 + response_body = json.loads(result.body) + assert response_body["echo"] == {} + assert response_body["path"] == "/health" + + # Test SSE streaming response + result = await sandbox_proxy_service.post_proxy( + sandbox_id=sandbox_id, + target_path="stream", + body={"msg": "hello"}, + headers=mock_headers, + ) + assert result.status_code == 200 + assert result.media_type == "text/event-stream" + + # Collect all streamed chunks + chunks = [] + async for chunk in result.body_iterator: + if isinstance(chunk, bytes): + chunk = chunk.decode() + chunks.append(chunk) + + full_response = "".join(chunks) + logger.info(f"streaming response: {full_response}") + + # Verify SSE format and content + assert "data: [DONE]" in full_response + + events = [line for line in full_response.strip().split("\n\n") if line.startswith("data: ")] + data_events = [e for e in events if e != "data: [DONE]"] + + assert len(data_events) == 3 + for i, event in enumerate(data_events): + event_data = json.loads(event.replace("data: ", "")) + assert event_data["index"] == i + assert event_data["echo"] == {"msg": "hello"} + + finally: + try: + close_req = CloseBashSessionRequest( + session=SESSION_NAME, + sandbox_id=sandbox_id, + ) + await sandbox_proxy_service.close_session(close_req) + except Exception: + pass + await sandbox_manager.stop(sandbox_id) diff --git a/tests/unit/sandbox/test_sandbox_manager.py b/tests/unit/sandbox/test_sandbox_manager.py index f8a66d51e..657e0f3ab 100644 --- a/tests/unit/sandbox/test_sandbox_manager.py +++ b/tests/unit/sandbox/test_sandbox_manager.py @@ -143,7 +143,7 @@ async def test_get_system_resource_info(sandbox_manager): @pytest.mark.asyncio async def test_get_status_state(sandbox_manager): response = await sandbox_manager.start_async( - DockerDeploymentConfig(), + DockerDeploymentConfig(cpus=0.5, memory="1g"), ) sandbox_id = response.sandbox_id await check_sandbox_status_until_alive(sandbox_manager, sandbox_id) @@ -157,7 +157,9 @@ async def test_get_status_state(sandbox_manager): async def test_sandbox_start_with_sandbox_id(sandbox_manager): try: sandbox_id = uuid.uuid4().hex - response = await sandbox_manager.start_async(DockerDeploymentConfig(container_name=sandbox_id)) + response = await sandbox_manager.start_async( + DockerDeploymentConfig(container_name=sandbox_id, cpus=0.5, memory="1g") + ) assert response.sandbox_id == sandbox_id await check_sandbox_status_until_alive(sandbox_manager, sandbox_id) with pytest.raises(BadRequestRockError) as e: diff --git a/tests/unit/sandbox/test_sandbox_proxy_router.py b/tests/unit/sandbox/test_sandbox_proxy_router.py new file mode 100644 index 000000000..60abf7d56 --- /dev/null +++ b/tests/unit/sandbox/test_sandbox_proxy_router.py @@ -0,0 +1,54 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from rock.admin.entrypoints.sandbox_proxy_api import sandbox_proxy_router, set_sandbox_proxy_service + + +@pytest.fixture +def app(): + mock_service = MagicMock() + mock_service.post_proxy = AsyncMock(return_value={"ok": True}) + set_sandbox_proxy_service(mock_service) + + app = FastAPI() + app.include_router(sandbox_proxy_router) + return app, mock_service + + +@pytest.mark.asyncio +async def test_post_proxy_path_parsing(app): + app, mock_service = app + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + # No path + await client.post("/sandboxes/sandbox-id/proxy", json={"key": "value"}) + args = mock_service.post_proxy.call_args + assert args[0][0] == "sandbox-id" + assert args[0][1] == "" + assert args[0][2] == {"key": "value"} + mock_service.post_proxy.reset_mock() + + # Single path segment + await client.post("/sandboxes/sandbox-id/proxy/health", json={}) + args = mock_service.post_proxy.call_args + assert args[0][0] == "sandbox-id" + assert args[0][1] == "health" + mock_service.post_proxy.reset_mock() + + # Nested path + await client.post("/sandboxes/sandbox-id/proxy/api/v1/chat", json={"msg": "hi"}) + args = mock_service.post_proxy.call_args + assert args[0][0] == "sandbox-id" + assert args[0][1] == "api/v1/chat" + assert args[0][2] == {"msg": "hi"} + mock_service.post_proxy.reset_mock() + + # Deep nested path + await client.post("/sandboxes/sandbox-id/proxy/a/b/c/d", json=None) + args = mock_service.post_proxy.call_args + assert args[0][0] == "sandbox-id" + assert args[0][1] == "a/b/c/d"