Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion rock/admin/entrypoints/sandbox_proxy_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import logging
from typing import Any

from fastapi import APIRouter, File, Form, Request, UploadFile, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Body, File, Form, Request, UploadFile, WebSocket, WebSocketDisconnect

from rock.actions import (
BashObservation,
Expand Down Expand Up @@ -36,6 +37,7 @@ def set_sandbox_proxy_service(service: SandboxProxyService):
global sandbox_proxy_service
sandbox_proxy_service = service


@sandbox_proxy_router.post("/execute")
@handle_exceptions(error_message="execute command failed")
async def execute(command: SandboxCommand) -> RockResponse[CommandResponse]:
Expand Down Expand Up @@ -127,3 +129,15 @@ async def websocket_proxy(websocket: WebSocket, id: str, path: str = ""):
async def get_token():
result = await asyncio.to_thread(sandbox_proxy_service.gen_oss_sts_token)
return RockResponse(result=result)


@sandbox_proxy_router.post("/sandboxes/{sandbox_id}/proxy")
@sandbox_proxy_router.post("/sandboxes/{sandbox_id}/proxy/{path:path}")
@handle_exceptions(error_message="post proxy failed")
async def post_proxy(
sandbox_id: str,
request: Request,
path: str = "",
body: dict[str, Any] = Body(None),
):
return await sandbox_proxy_service.post_proxy(sandbox_id, path, body, request.headers)
92 changes: 91 additions & 1 deletion rock/sandbox/service/sandbox_proxy_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import asyncio # noqa: I001
import json
import time
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.datastructures import Headers

import httpx
import oss2
import websockets
from aliyunsdkcore import client
from aliyunsdkcore.request import CommonRequest
from fastapi import UploadFile
from fastapi import Response, UploadFile
from starlette.status import HTTP_504_GATEWAY_TIMEOUT

from rock import env_vars
Expand Down Expand Up @@ -411,3 +413,91 @@ def _matches_query_params(self, sandbox_info: SandboxInfo, query_params: Sandbox
if filter_key not in sandbox_info or sandbox_info[filter_key] != filter_value:
return False
return True

async def post_proxy(
self,
sandbox_id: str,
target_path: str,
body: dict | None,
headers: Headers,
) -> JSONResponse | StreamingResponse | Response:
"""HTTP POST proxy that supports both streaming (SSE) and non-streaming responses."""

EXCLUDED_HEADERS = {"host", "content-length", "transfer-encoding"}

def filter_headers(raw_headers: Headers) -> dict:
return {k: v for k, v in raw_headers.items() if k.lower() not in EXCLUDED_HEADERS}

status_list = await self.get_service_status(sandbox_id)
service_status = ServiceStatus.from_dict(status_list[0])

host_ip = status_list[0].get("host_ip")
port = service_status.get_mapped_port(Port.SERVER)
target_url = f"http://{host_ip}:{port}/{target_path}"

request_headers = filter_headers(headers)
payload = body or {}

client = httpx.AsyncClient(timeout=httpx.Timeout(None))

try:
resp = await client.send(
client.build_request(
method="POST",
url=target_url,
json=payload,
headers=request_headers,
timeout=120,
),
stream=True,
)
except Exception:
await client.aclose()
raise

content_type = resp.headers.get("content-type", "")
is_sse = "text/event-stream" in content_type
response_headers = filter_headers(resp.headers)

if is_sse:

async def event_stream():
"""Forward upstream bytes to downstream as soon as they arrive."""
try:
if resp.status_code >= 400:
yield await resp.aread()
return

async for chunk in resp.aiter_bytes():
if chunk:
yield chunk
finally:
await resp.aclose()
await client.aclose()

return StreamingResponse(
event_stream(),
status_code=resp.status_code,
media_type="text/event-stream",
headers=response_headers,
)

try:
raw_content = await resp.aread()

if "application/json" in content_type:
return JSONResponse(
status_code=resp.status_code,
content=resp.json(),
headers=response_headers,
)

return Response(
status_code=resp.status_code,
content=raw_content,
media_type=content_type or "application/octet-stream",
headers=response_headers,
)
finally:
await resp.aclose()
await client.aclose()
184 changes: 184 additions & 0 deletions tests/unit/sandbox/test_sandbox_http_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import asyncio
import json

import pytest
from starlette.datastructures import Headers

from rock.admin.proto.request import (
SandboxBashAction as BashAction,
)
from rock.admin.proto.request import (
SandboxCloseBashSessionRequest as CloseBashSessionRequest,
)
from rock.admin.proto.request import (
SandboxCreateBashSessionRequest as CreateSessionRequest,
)
from rock.deployments.config import DockerDeploymentConfig
from rock.logger import init_logger
from rock.sandbox.sandbox_manager import SandboxManager
from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService
from tests.unit.conftest import check_sandbox_status_until_alive

logger = init_logger(__name__)

ECHO_SERVER_SCRIPT = r"""
import json
from http.server import HTTPServer, BaseHTTPRequestHandler

class Handler(BaseHTTPRequestHandler):
def do_POST(self):
content_length = int(self.headers.get("Content-Length", 0))
body = self.rfile.read(content_length) if content_length else b"{}"
data = json.loads(body)

if self.path == "/stream":
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.end_headers()
for i in range(3):
chunk = {"index": i, "echo": data}
self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode())
self.wfile.flush()
self.wfile.write(b"data: [DONE]\n\n")
self.wfile.flush()
else:
response = {"path": self.path, "echo": data}
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(response).encode())

def log_message(self, format, *args):
pass

HTTPServer(("0.0.0.0", 8080), Handler).serve_forever()
"""

SESSION_NAME = "test"


async def start_echo_server_in_sandbox(
sandbox_proxy_service: SandboxProxyService,
sandbox_id: str,
) -> None:
"""Start an echo HTTP server inside the sandbox container via create_session + run_in_session."""

create_req = CreateSessionRequest(
session=SESSION_NAME,
sandbox_id=sandbox_id,
)
await sandbox_proxy_service.create_session(create_req)

# Write echo server script to file
write_action = BashAction(
action_type="bash",
sandbox_id=sandbox_id,
session=SESSION_NAME,
command="cat > /tmp/echo_server.py << 'PYEOF'\n" + ECHO_SERVER_SCRIPT.strip() + "\nPYEOF",
)
await sandbox_proxy_service.run_in_session(write_action)

# Start echo server in background
start_action = BashAction(
action_type="bash",
sandbox_id=sandbox_id,
session=SESSION_NAME,
command="nohup python3 /tmp/echo_server.py > /tmp/server.log 2>&1 & echo $!",
)
start_result = await sandbox_proxy_service.run_in_session(start_action)
logger.info(f"echo server started, result: {start_result}")

# Wait for server to be ready
await asyncio.sleep(2)


@pytest.mark.need_ray
@pytest.mark.asyncio
async def test_post_proxy(sandbox_manager: SandboxManager, sandbox_proxy_service: SandboxProxyService):
response = await sandbox_manager.start_async(DockerDeploymentConfig(cpus=0.5, memory="1g"))
sandbox_id = response.sandbox_id
await check_sandbox_status_until_alive(sandbox_manager, sandbox_id)

try:
await start_echo_server_in_sandbox(sandbox_proxy_service, sandbox_id)

mock_headers = Headers({"content-type": "application/json"})

# Test with path and body
result = await sandbox_proxy_service.post_proxy(
sandbox_id=sandbox_id,
target_path="api/test",
body={"hello": "world"},
headers=mock_headers,
)
assert result.status_code == 200
response_body = json.loads(result.body)
assert response_body["path"] == "/api/test"
assert response_body["echo"] == {"hello": "world"}

# Test without path
result = await sandbox_proxy_service.post_proxy(
sandbox_id=sandbox_id,
target_path="",
body={"key": "value"},
headers=mock_headers,
)
assert result.status_code == 200
response_body = json.loads(result.body)
assert response_body["echo"] == {"key": "value"}

# Test with body as None
result = await sandbox_proxy_service.post_proxy(
sandbox_id=sandbox_id,
target_path="health",
body=None,
headers=mock_headers,
)
assert result.status_code == 200
response_body = json.loads(result.body)
assert response_body["echo"] == {}
assert response_body["path"] == "/health"

# Test SSE streaming response
result = await sandbox_proxy_service.post_proxy(
sandbox_id=sandbox_id,
target_path="stream",
body={"msg": "hello"},
headers=mock_headers,
)
assert result.status_code == 200
assert result.media_type == "text/event-stream"

# Collect all streamed chunks
chunks = []
async for chunk in result.body_iterator:
if isinstance(chunk, bytes):
chunk = chunk.decode()
chunks.append(chunk)

full_response = "".join(chunks)
logger.info(f"streaming response: {full_response}")

# Verify SSE format and content
assert "data: [DONE]" in full_response

events = [line for line in full_response.strip().split("\n\n") if line.startswith("data: ")]
data_events = [e for e in events if e != "data: [DONE]"]

assert len(data_events) == 3
for i, event in enumerate(data_events):
event_data = json.loads(event.replace("data: ", ""))
assert event_data["index"] == i
assert event_data["echo"] == {"msg": "hello"}

finally:
try:
close_req = CloseBashSessionRequest(
session=SESSION_NAME,
sandbox_id=sandbox_id,
)
await sandbox_proxy_service.close_session(close_req)
except Exception:
pass
await sandbox_manager.stop(sandbox_id)
6 changes: 4 additions & 2 deletions tests/unit/sandbox/test_sandbox_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async def test_get_system_resource_info(sandbox_manager):
@pytest.mark.asyncio
async def test_get_status_state(sandbox_manager):
response = await sandbox_manager.start_async(
DockerDeploymentConfig(),
DockerDeploymentConfig(cpus=0.5, memory="1g"),
)
sandbox_id = response.sandbox_id
await check_sandbox_status_until_alive(sandbox_manager, sandbox_id)
Expand All @@ -157,7 +157,9 @@ async def test_get_status_state(sandbox_manager):
async def test_sandbox_start_with_sandbox_id(sandbox_manager):
try:
sandbox_id = uuid.uuid4().hex
response = await sandbox_manager.start_async(DockerDeploymentConfig(container_name=sandbox_id))
response = await sandbox_manager.start_async(
DockerDeploymentConfig(container_name=sandbox_id, cpus=0.5, memory="1g")
)
assert response.sandbox_id == sandbox_id
await check_sandbox_status_until_alive(sandbox_manager, sandbox_id)
with pytest.raises(BadRequestRockError) as e:
Expand Down
Loading
Loading