diff --git a/rock/admin/main.py b/rock/admin/main.py index 02569661b..a12588bae 100644 --- a/rock/admin/main.py +++ b/rock/admin/main.py @@ -22,6 +22,7 @@ from rock.config import RockConfig from rock.logger import init_logger from rock.sandbox.gem_manager import GemManager +from rock.sandbox.operator.factory import OperatorContext, OperatorFactory from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService from rock.sandbox.service.warmup_service import WarmupService from rock.utils import EAGLE_EYE_TRACE_ID, sandbox_id_ctx_var, trace_id_ctx_var @@ -70,6 +71,13 @@ async def lifespan(app: FastAPI): ray_service = RayService(rock_config.ray) ray_service.init() + # create operator using factory with context pattern + operator_context = OperatorContext( + runtime_config=rock_config.runtime, + ray_service=ray_service, + ) + operator = OperatorFactory.create_operator(operator_context) + # init service if rock_config.runtime.enable_auto_clear: sandbox_manager = GemManager( @@ -78,6 +86,7 @@ async def lifespan(app: FastAPI): ray_namespace=rock_config.ray.namespace, ray_service=ray_service, enable_runtime_auto_clear=True, + operator=operator, ) else: sandbox_manager = GemManager( @@ -86,6 +95,7 @@ async def lifespan(app: FastAPI): ray_namespace=rock_config.ray.namespace, ray_service=ray_service, enable_runtime_auto_clear=False, + operator=operator, ) set_sandbox_manager(sandbox_manager) warmup_service = WarmupService(rock_config.warmup) diff --git a/rock/config.py b/rock/config.py index 7f8e758ba..884d95c52 100644 --- a/rock/config.py +++ b/rock/config.py @@ -110,6 +110,7 @@ class RuntimeConfig: project_root: str = field(default_factory=lambda: env_vars.ROCK_PROJECT_ROOT) python_env_path: str = field(default_factory=lambda: env_vars.ROCK_PYTHON_ENV_PATH) envhub_db_url: str = field(default_factory=lambda: env_vars.ROCK_ENVHUB_DB_URL) + operator_type: str = "ray" standard_spec: StandardSpec = field(default_factory=StandardSpec) max_allowed_spec: StandardSpec = field(default_factory=lambda: StandardSpec(cpus=16, memory="64g")) diff --git a/rock/sandbox/gem_manager.py b/rock/sandbox/gem_manager.py index 224e634c8..2cb2aab1c 100644 --- a/rock/sandbox/gem_manager.py +++ b/rock/sandbox/gem_manager.py @@ -19,7 +19,6 @@ from rock.sandbox.sandbox_actor import SandboxActor from rock.sandbox.sandbox_manager import SandboxManager from rock.utils.providers import RedisProvider -from rock.admin.core.ray_service import RayService class GemManager(SandboxManager): @@ -30,8 +29,9 @@ def __init__( ray_namespace: str = env_vars.ROCK_RAY_NAMESPACE, ray_service: RayService | None = None, enable_runtime_auto_clear: bool = False, + operator=None, ): - super().__init__(rock_config, redis_provider, ray_namespace, ray_service, enable_runtime_auto_clear) + super().__init__(rock_config, redis_provider, ray_namespace, ray_service, enable_runtime_auto_clear, operator) async def env_make(self, env_id: str) -> EnvMakeResponse: config = DockerDeploymentConfig(image=env_vars.ROCK_ENVHUB_DEFAULT_DOCKER_IMAGE) diff --git a/rock/sandbox/operator/abstract.py b/rock/sandbox/operator/abstract.py index 77a2c2177..d80aaa554 100644 --- a/rock/sandbox/operator/abstract.py +++ b/rock/sandbox/operator/abstract.py @@ -2,11 +2,14 @@ from rock.actions.sandbox.sandbox_info import SandboxInfo from rock.deployments.config import DeploymentConfig +from rock.utils.providers.redis_provider import RedisProvider class AbstractOperator(ABC): + _redis_provider: RedisProvider | None = None + @abstractmethod - async def submit(self, config: DeploymentConfig) -> SandboxInfo: + async def submit(self, config: DeploymentConfig, user_info: dict = {}) -> SandboxInfo: ... @abstractmethod @@ -15,4 +18,7 @@ async def get_status(self, sandbox_id: str) -> SandboxInfo: @abstractmethod async def stop(self, sandbox_id: str) -> bool: - ... \ No newline at end of file + ... + + def set_redis_provider(self, redis_provider: RedisProvider): + self._redis_provider = redis_provider diff --git a/rock/sandbox/operator/factory.py b/rock/sandbox/operator/factory.py new file mode 100644 index 000000000..4dced003c --- /dev/null +++ b/rock/sandbox/operator/factory.py @@ -0,0 +1,60 @@ +"""Operator factory for creating operator instances based on configuration.""" + +from dataclasses import dataclass, field +from typing import Any + +from rock.admin.core.ray_service import RayService +from rock.config import RuntimeConfig +from rock.logger import init_logger +from rock.sandbox.operator.abstract import AbstractOperator +from rock.sandbox.operator.ray import RayOperator + +logger = init_logger(__name__) + + +@dataclass +class OperatorContext: + """Context object containing all dependencies needed for operator creation. + + This design pattern solves the parameter explosion problem by encapsulating + all dependencies in a single context object. New operator types can add their + dependencies to this context without changing the factory method signature. + """ + + runtime_config: RuntimeConfig + ray_service: RayService | None = None + # Future operator dependencies can be added here without breaking existing code + # kubernetes_client: Any | None = None + # docker_client: Any | None = None + extra_params: dict[str, Any] = field(default_factory=dict) + + +class OperatorFactory: + """Factory class for creating operator instances. + + Uses the Context Object pattern to avoid parameter explosion as new + operator types are added. + """ + + @staticmethod + def create_operator(context: OperatorContext) -> AbstractOperator: + """Create an operator instance based on the runtime configuration. + + Args: + context: OperatorContext containing all necessary dependencies + + Returns: + AbstractOperator: The created operator instance + + Raises: + ValueError: If operator_type is not supported or required dependencies are missing + """ + operator_type = context.runtime_config.operator_type.lower() + + if operator_type == "ray": + if context.ray_service is None: + raise ValueError("RayService is required for RayOperator") + logger.info("Creating RayOperator") + return RayOperator(ray_service=context.ray_service) + else: + raise ValueError(f"Unsupported operator type: {operator_type}. " f"Supported types: ray, kubernetes") diff --git a/rock/sandbox/operator/ray.py b/rock/sandbox/operator/ray.py index c312aca55..87793d4d6 100644 --- a/rock/sandbox/operator/ray.py +++ b/rock/sandbox/operator/ray.py @@ -1,14 +1,102 @@ +import json + +import ray + +from rock.actions.sandbox.response import State from rock.actions.sandbox.sandbox_info import SandboxInfo -from rock.deployments.config import DeploymentConfig +from rock.admin.core.ray_service import RayService +from rock.admin.core.redis_key import alive_sandbox_key +from rock.deployments.config import DockerDeploymentConfig +from rock.deployments.docker import DockerDeployment +from rock.deployments.status import ServiceStatus +from rock.logger import init_logger from rock.sandbox.operator.abstract import AbstractOperator +from rock.sandbox.sandbox_actor import SandboxActor +from rock.sdk.common.exceptions import BadRequestRockError +from rock.utils.format import parse_memory_size + +logger = init_logger(__name__) class RayOperator(AbstractOperator): - async def submit(self, config: DeploymentConfig) -> SandboxInfo: - return SandboxInfo(sandbox_id="test", host_name="test", host_ip="test") + def __init__(self, ray_service: RayService): + self._ray_service = ray_service + + def _get_actor_name(self, sandbox_id: str) -> str: + return f"sandbox-{sandbox_id}" + + async def create_actor(self, config: DockerDeploymentConfig): + actor_options = self._generate_actor_options(config) + deployment: DockerDeployment = config.get_deployment() + sandbox_actor = SandboxActor.options(**actor_options).remote(config, deployment) + return sandbox_actor + + def _generate_actor_options(self, config: DockerDeploymentConfig) -> dict: + actor_name = self._get_actor_name(config.container_name) + actor_options = {"name": actor_name, "lifetime": "detached"} + try: + memory = parse_memory_size(config.memory) + actor_options["num_cpus"] = config.cpus + actor_options["memory"] = memory + return actor_options + except ValueError as e: + logger.warning(f"Invalid memory size: {config.memory}", exc_info=e) + raise BadRequestRockError(f"Invalid memory size: {config.memory}") + + async def submit(self, config: DockerDeploymentConfig, user_info: dict = {}) -> SandboxInfo: + async with self._ray_service.get_ray_rwlock().read_lock(): + sandbox_id = config.container_name + logger.info(f"[{sandbox_id}] start_async params:{json.dumps(config.model_dump(), indent=2)}") + sandbox_actor: SandboxActor = await self.create_actor(config) + sandbox_actor.start.remote() + user_id = user_info.get("user_id", "default") + experiment_id = user_info.get("experiment_id", "default") + namespace = user_info.get("namespace", "default") + rock_authorization = user_info.get("rock_authorization", "default") + sandbox_actor.set_user_id.remote(user_id) + sandbox_actor.set_experiment_id.remote(experiment_id) + sandbox_actor.set_namespace.remote(namespace) + sandbox_info: SandboxInfo = await self._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote()) + sandbox_info["user_id"] = user_id + sandbox_info["experiment_id"] = experiment_id + sandbox_info["namespace"] = namespace + sandbox_info["state"] = State.PENDING + sandbox_info["rock_authorization"] = rock_authorization + logger.info(f"sandbox {sandbox_id} is submitted") + return sandbox_info async def get_status(self, sandbox_id: str) -> SandboxInfo: - return SandboxInfo(sandbox_id="test", host_name="test", host_ip="test") + async with self._ray_service.get_ray_rwlock().read_lock(): + actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id)) + sandbox_info: SandboxInfo = await self._ray_service.async_ray_get(actor.sandbox_info.remote()) + remote_status: ServiceStatus = await self._ray_service.async_ray_get(actor.get_status.remote()) + sandbox_info["phases"] = remote_status.phases + sandbox_info["port_mapping"] = remote_status.get_port_mapping() + alive = await self._ray_service.async_ray_get(actor.is_alive.remote()) + if alive.is_alive: + sandbox_info["state"] = State.RUNNING + if not self._redis_provider: + return sandbox_info + redis_info = await self.get_sandbox_info_from_redis(sandbox_id) + if redis_info: + redis_info.update(sandbox_info) + redis_info["phases"] = {name: phase.to_dict() for name, phase in remote_status.phases.items()} + return redis_info + else: + return sandbox_info + # return sandbox_info + + async def get_sandbox_info_from_redis(self, sandbox_id: str) -> SandboxInfo: + sandbox_status = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$") + if sandbox_status and len(sandbox_status) > 0: + sandbox_info = sandbox_status[0] + return sandbox_info + return None async def stop(self, sandbox_id: str) -> bool: - return True + async with self._ray_service.get_ray_rwlock().read_lock(): + actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id)) + await self._ray_service.async_ray_get(actor.stop.remote()) + logger.info(f"run time stop over {sandbox_id}") + ray.kill(actor) + return True diff --git a/rock/sandbox/sandbox_manager.py b/rock/sandbox/sandbox_manager.py index b26092320..2a58864e8 100644 --- a/rock/sandbox/sandbox_manager.py +++ b/rock/sandbox/sandbox_manager.py @@ -1,8 +1,6 @@ import asyncio -import json import time -import ray from fastapi import UploadFile from rock import env_vars @@ -73,6 +71,8 @@ def __init__( self._ray_namespace = ray_namespace self._operator = operator self._aes_encrypter = AESEncryption() + if redis_provider: + self._operator.set_redis_provider(redis_provider) logger.info("sandbox service init success") async def refresh_aes_key(self): @@ -84,7 +84,6 @@ async def refresh_aes_key(self): logger.error(f"update aes key failed, error: {e}") raise InternalServerRockError(f"update aes key failed, {str(e)}") - async def _check_sandbox_exists_in_redis(self, config: DeploymentConfig): if isinstance(config, DockerDeploymentConfig) and config.container_name: sandbox_id = config.container_name @@ -118,39 +117,25 @@ async def _build_sandbox_info_metadata( async def start_async( self, config: DeploymentConfig, user_info: UserInfo = {}, cluster_info: ClusterInfo = {} ) -> SandboxStartResponse: - async with self._ray_service.get_ray_rwlock().read_lock(): - await self._check_sandbox_exists_in_redis(config) - docker_deployment_config: DockerDeploymentConfig = await self.deployment_manager.init_config(config) - sandbox_id = docker_deployment_config.container_name - logger.info( - f"[{sandbox_id}] start_async params:{json.dumps(docker_deployment_config.model_dump(), indent=2)}" - ) - actor_name = self.deployment_manager.get_actor_name(sandbox_id) - - deployment = docker_deployment_config.get_deployment() - - self.validate_sandbox_spec(self.rock_config.runtime, config) - sandbox_actor: SandboxActor = await deployment.creator_actor(actor_name) - sandbox_actor.start.remote() - self._setup_sandbox_actor_metadata(sandbox_actor, user_info) - - self._sandbox_meta[sandbox_id] = {"image": docker_deployment_config.image} - logger.info(f"sandbox {sandbox_id} is submitted") - stop_time = str(int(time.time()) + docker_deployment_config.auto_clear_time * 60) - auto_clear_time_dict = { - env_vars.ROCK_SANDBOX_AUTO_CLEAR_TIME_KEY: str(docker_deployment_config.auto_clear_time), - env_vars.ROCK_SANDBOX_EXPIRE_TIME_KEY: stop_time, - } - sandbox_info: SandboxInfo = await self._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote()) - await self._build_sandbox_info_metadata(sandbox_info, user_info, cluster_info) - if self._redis_provider: - await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) - await self._redis_provider.json_set(timeout_sandbox_key(sandbox_id), "$", auto_clear_time_dict) - return SandboxStartResponse( - sandbox_id=sandbox_id, - host_name=sandbox_info.get("host_name"), - host_ip=sandbox_info.get("host_ip"), - ) + await self._check_sandbox_exists_in_redis(config) + self.validate_sandbox_spec(self.rock_config.runtime, config) + docker_deployment_config: DockerDeploymentConfig = await self.deployment_manager.init_config(config) + sandbox_id = docker_deployment_config.container_name + sandbox_info: SandboxInfo = await self._operator.submit(docker_deployment_config, user_info) + stop_time = str(int(time.time()) + docker_deployment_config.auto_clear_time * 60) + auto_clear_time_dict = { + env_vars.ROCK_SANDBOX_AUTO_CLEAR_TIME_KEY: str(docker_deployment_config.auto_clear_time), + env_vars.ROCK_SANDBOX_EXPIRE_TIME_KEY: stop_time, + } + await self._build_sandbox_info_metadata(sandbox_info, user_info, cluster_info) + if self._redis_provider: + await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) + await self._redis_provider.json_set(timeout_sandbox_key(sandbox_id), "$", auto_clear_time_dict) + return SandboxStartResponse( + sandbox_id=sandbox_id, + host_name=sandbox_info.get("host_name"), + host_ip=sandbox_info.get("host_ip"), + ) @monitor_sandbox_operation() async def start(self, config: DeploymentConfig) -> SandboxStartResponse: @@ -181,28 +166,22 @@ async def start(self, config: DeploymentConfig) -> SandboxStartResponse: @monitor_sandbox_operation() async def stop(self, sandbox_id): - async with self._ray_service.get_ray_rwlock().read_lock(): - logger.info(f"stop sandbox {sandbox_id}") - sandbox_info: SandboxInfo = await build_sandbox_from_redis(self._redis_provider, sandbox_id) - if sandbox_info and sandbox_info.get("start_time"): - sandbox_info["stop_time"] = get_iso8601_timestamp() - log_billing_info(sandbox_info=sandbox_info) - try: - actor_name = self.deployment_manager.get_actor_name(sandbox_id) - sandbox_actor = await self._ray_service.async_ray_get_actor(actor_name, self._ray_namespace) - except ValueError as e: - await self._clear_redis_keys(sandbox_id) - raise Exception(f"sandbox {sandbox_id} not found to stop, {str(e)}") - logger.info(f"start to stop run time {sandbox_id}") - await self._ray_service.async_ray_get(sandbox_actor.stop.remote()) - logger.info(f"run time stop over {sandbox_id}") - ray.kill(sandbox_actor) - try: - self._sandbox_meta.pop(sandbox_id) - except KeyError: - logger.debug(f"{sandbox_id} key not found") - logger.info(f"sandbox {sandbox_id} stopped") + logger.info(f"stop sandbox {sandbox_id}") + sandbox_info: SandboxInfo = await build_sandbox_from_redis(self._redis_provider, sandbox_id) + if sandbox_info and sandbox_info.get("start_time"): + sandbox_info["stop_time"] = get_iso8601_timestamp() + log_billing_info(sandbox_info=sandbox_info) + try: + await self._operator.stop(sandbox_id) + except ValueError as e: + logger.error(f"ray get actor, actor {sandbox_id} not exist", exc_info=e) await self._clear_redis_keys(sandbox_id) + try: + self._sandbox_meta.pop(sandbox_id) + except KeyError: + logger.debug(f"{sandbox_id} key not found") + logger.info(f"sandbox {sandbox_id} stopped") + await self._clear_redis_keys(sandbox_id) async def get_mount(self, sandbox_id): async with self._ray_service.get_ray_rwlock().read_lock(): @@ -236,61 +215,50 @@ async def _clear_redis_keys(self, sandbox_id): logger.info(f"sandbox {sandbox_id} deleted from redis") @monitor_sandbox_operation() - async def get_status(self, sandbox_id) -> SandboxStatusResponse: - async with self._ray_service.get_ray_rwlock().read_lock(): - actor_name = self.deployment_manager.get_actor_name(sandbox_id) - sandbox_actor = await self._ray_service.async_ray_get_actor(actor_name, self._ray_namespace) - if sandbox_actor is None: - raise Exception(f"sandbox {sandbox_id} not found to get status") - else: - remote_status: ServiceStatus = await self._ray_service.async_ray_get(sandbox_actor.get_status.remote()) - alive = await self._ray_service.async_ray_get(sandbox_actor.is_alive.remote()) - sandbox_info: SandboxInfo = None - if self._redis_provider: - sandbox_info = await build_sandbox_from_redis(self._redis_provider, sandbox_id) - if sandbox_info is None: - # The start() method will write to redis on the first call to get_status() - sandbox_info = await self._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote()) - sandbox_info.update(remote_status.to_dict()) - self._update_sandbox_alive_info(sandbox_info, alive.is_alive) - await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) - await self._update_expire_time(sandbox_id) - logger.info(f"sandbox {sandbox_id} status is {sandbox_info}, write to redis") - else: - sandbox_info = await self._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote()) - - return SandboxStatusResponse( - sandbox_id=sandbox_id, - status=remote_status.phases, - state=sandbox_info.get("state"), - port_mapping=remote_status.get_port_mapping(), - host_name=sandbox_info.get("host_name"), - host_ip=sandbox_info.get("host_ip"), - is_alive=alive.is_alive, - image=sandbox_info.get("image"), - swe_rex_version=swe_version, - gateway_version=gateway_version, - user_id=sandbox_info.get("user_id"), - experiment_id=sandbox_info.get("experiment_id"), - namespace=sandbox_info.get("namespace"), - cpus=sandbox_info.get("cpus"), - memory=sandbox_info.get("memory"), - ) - - async def _get_sandbox_info(self, sandbox_id: str) -> SandboxInfo: - """Get sandbox info, prioritize Redis, fallback to Ray Actor""" - if self._redis_provider: - sandbox_info = await build_sandbox_from_redis(self._redis_provider, sandbox_id) + async def get_status(self, sandbox_id, use_rocklet: bool = False) -> SandboxStatusResponse: + if use_rocklet and self._redis_provider: + sandbox_info: SandboxInfo = await build_sandbox_from_redis(self._redis_provider, sandbox_id) + host_ip = sandbox_info.get("host_ip") + remote_status = await self.get_remote_status(sandbox_id, host_ip) + is_alive = await self._check_alive_status(sandbox_id, host_ip, remote_status) + sandbox_info.update(remote_status.to_dict()) else: - actor_name = self.deployment_manager.get_actor_name(sandbox_id) - sandbox_actor = await self._ray_service.async_ray_get_actor(actor_name, self._ray_namespace) - if sandbox_actor is None: - raise Exception(f"sandbox {sandbox_id} not found to get status") - sandbox_info = await self._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote()) - - if sandbox_info is None: - raise Exception(f"sandbox {sandbox_id} not found to get status") + sandbox_info: SandboxInfo = await self._operator.get_status(sandbox_id=sandbox_id) + is_alive = sandbox_info.get("state") == State.RUNNING + self._update_sandbox_alive_info(sandbox_info, is_alive) + if self._redis_provider: + await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) + (self._update_expire_time(sandbox_id),) + return SandboxStatusResponse( + sandbox_id=sandbox_id, + status=sandbox_info.get("phases"), + port_mapping=sandbox_info.get("port_mapping"), + state=sandbox_info.get("state"), + host_name=sandbox_info.get("host_name"), + host_ip=sandbox_info.get("host_ip"), + is_alive=is_alive, + image=sandbox_info.get("image"), + swe_rex_version=swe_version, + gateway_version=gateway_version, + user_id=sandbox_info.get("user_id"), + experiment_id=sandbox_info.get("experiment_id"), + namespace=sandbox_info.get("namespace"), + cpus=sandbox_info.get("cpus"), + memory=sandbox_info.get("memory"), + ) + async def build_sandbox_info_from_redis(self, sandbox_id: str, deployment_info: SandboxInfo) -> SandboxInfo | None: + sandbox_status = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$") + if sandbox_status and len(sandbox_status) > 0: + sandbox_info = sandbox_status[0] + remote_info = { + k: v for k, v in deployment_info.items() if k in ["phases", "port_mapping", "alive", "state"] + } + if "phases" in remote_info and remote_info["phases"]: + remote_info["phases"] = {name: phase.to_dict() for name, phase in remote_info["phases"].items()} + sandbox_info.update(remote_info) + else: + sandbox_info = deployment_info return sandbox_info async def _check_alive_status(self, sandbox_id: str, host_ip: str, remote_status: ServiceStatus) -> bool: @@ -314,46 +282,12 @@ def _update_sandbox_alive_info(self, sandbox_info: SandboxInfo, is_alive: bool) if sandbox_info.get("start_time") is None: sandbox_info["start_time"] = get_iso8601_timestamp() - @monitor_sandbox_operation() async def get_status_v2(self, sandbox_id) -> SandboxStatusResponse: - # 1. Get sandbox_info (unified exception handling) - sandbox_info = await self._get_sandbox_info(sandbox_id) - - # 2. Parallel execution: update expire time & get remote status - host_ip = sandbox_info.get("host_ip") - _, remote_status = await asyncio.gather( - self._update_expire_time(sandbox_id), - self.get_remote_status(sandbox_id, host_ip), - ) - - # 3. Update sandbox_info and check alive status - sandbox_info.update(remote_status.to_dict()) - is_alive = await self._check_alive_status(sandbox_id, host_ip, remote_status) - self._update_sandbox_alive_info(sandbox_info, is_alive) - - # 4. Persist to Redis if Redis exists - if self._redis_provider: - await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) - logger.info(f"sandbox {sandbox_id} status is {remote_status}, write to redis") - - # 5. Build and return response - return SandboxStatusResponse( - sandbox_id=sandbox_id, - status=remote_status.phases, - port_mapping=remote_status.get_port_mapping(), - state=sandbox_info.get("state"), - host_name=sandbox_info.get("host_name"), - host_ip=sandbox_info.get("host_ip"), - is_alive=is_alive, - image=sandbox_info.get("image"), - swe_rex_version=swe_version, - gateway_version=gateway_version, - user_id=sandbox_info.get("user_id"), - experiment_id=sandbox_info.get("experiment_id"), - namespace=sandbox_info.get("namespace"), - cpus=sandbox_info.get("cpus"), - memory=sandbox_info.get("memory"), - ) + """ + Deprecated: Use get_status(sandbox_id, use_rocklet=True) instead. + This method is kept for backward compatibility. + """ + return await self.get_status(sandbox_id, use_rocklet=True) async def get_remote_status(self, sandbox_id: str, host_ip: str) -> ServiceStatus: service_status_path = PersistedServiceStatus.gen_service_status_path(sandbox_id) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2fae04594..2dc33e37a 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,13 +7,14 @@ from fakeredis import aioredis from ray.util.state import list_actors +from rock.admin.core.ray_service import RayService from rock.config import RockConfig from rock.deployments.config import DockerDeploymentConfig from rock.logger import init_logger +from rock.sandbox.operator.ray import RayOperator from rock.sandbox.sandbox_manager import SandboxManager from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService from rock.utils.providers.redis_provider import RedisProvider -from rock.admin.core.ray_service import RayService logger = init_logger(__name__) @@ -37,19 +38,30 @@ async def redis_provider(): yield provider await provider.close_pool() + @pytest.fixture def ray_service(rock_config: RockConfig, ray_init_shutdown): ray_service = RayService(rock_config.ray) return ray_service + +@pytest.fixture +def ray_operator(ray_service): + ray_operator = RayOperator(ray_service) + return ray_operator + + @pytest.fixture -async def sandbox_manager(rock_config: RockConfig, redis_provider: RedisProvider, ray_init_shutdown, ray_service): +async def sandbox_manager( + rock_config: RockConfig, redis_provider: RedisProvider, ray_init_shutdown, ray_service, ray_operator +): sandbox_manager = SandboxManager( rock_config, redis_provider=redis_provider, ray_namespace=rock_config.ray.namespace, ray_service=ray_service, enable_runtime_auto_clear=rock_config.runtime.enable_auto_clear, + operator=ray_operator, ) return sandbox_manager diff --git a/tests/unit/sandbox/operator/test_ray_operator.py b/tests/unit/sandbox/operator/test_ray_operator.py index 209cd1b3b..9c8fa322c 100644 --- a/tests/unit/sandbox/operator/test_ray_operator.py +++ b/tests/unit/sandbox/operator/test_ray_operator.py @@ -6,17 +6,17 @@ @pytest.mark.asyncio -async def test_ray_operator(): - operator = RayOperator() - start_response: SandboxInfo = await operator.submit(DockerDeploymentConfig()) +async def test_ray_operator(ray_service): + operator = RayOperator(ray_service=ray_service) + start_response: SandboxInfo = await operator.submit(DockerDeploymentConfig(container_name="test")) assert start_response.get("sandbox_id") == "test" - assert start_response.get("host_name") == "test" - assert start_response.get("host_ip") == "test" - - stop_response: bool = await operator.stop("test") - assert stop_response + assert start_response.get("host_name") is not None + assert start_response.get("host_ip") is not None status_response: SandboxInfo = await operator.get_status("test") assert status_response.get("sandbox_id") == "test" - assert status_response.get("host_name") == "test" - assert status_response.get("host_ip") == "test" \ No newline at end of file + assert start_response.get("host_name") is not None + assert start_response.get("host_ip") is not None + + stop_response: bool = await operator.stop("test") + assert stop_response