Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions rock/admin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from rock.config import RockConfig
from rock.logger import init_logger
from rock.sandbox.gem_manager import GemManager
from rock.sandbox.operator.factory import OperatorContext, OperatorFactory
from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService
from rock.sandbox.service.warmup_service import WarmupService
from rock.utils import EAGLE_EYE_TRACE_ID, sandbox_id_ctx_var, trace_id_ctx_var
Expand Down Expand Up @@ -70,6 +71,13 @@ async def lifespan(app: FastAPI):
ray_service = RayService(rock_config.ray)
ray_service.init()

# create operator using factory with context pattern
operator_context = OperatorContext(
runtime_config=rock_config.runtime,
ray_service=ray_service,
)
operator = OperatorFactory.create_operator(operator_context)

# init service
if rock_config.runtime.enable_auto_clear:
sandbox_manager = GemManager(
Expand All @@ -78,6 +86,7 @@ async def lifespan(app: FastAPI):
ray_namespace=rock_config.ray.namespace,
ray_service=ray_service,
enable_runtime_auto_clear=True,
operator=operator,
)
else:
sandbox_manager = GemManager(
Expand All @@ -86,6 +95,7 @@ async def lifespan(app: FastAPI):
ray_namespace=rock_config.ray.namespace,
ray_service=ray_service,
enable_runtime_auto_clear=False,
operator=operator,
)
set_sandbox_manager(sandbox_manager)
warmup_service = WarmupService(rock_config.warmup)
Expand Down
1 change: 1 addition & 0 deletions rock/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class RuntimeConfig:
project_root: str = field(default_factory=lambda: env_vars.ROCK_PROJECT_ROOT)
python_env_path: str = field(default_factory=lambda: env_vars.ROCK_PYTHON_ENV_PATH)
envhub_db_url: str = field(default_factory=lambda: env_vars.ROCK_ENVHUB_DB_URL)
operator_type: str = "ray"
standard_spec: StandardSpec = field(default_factory=StandardSpec)
max_allowed_spec: StandardSpec = field(default_factory=lambda: StandardSpec(cpus=16, memory="64g"))

Expand Down
4 changes: 2 additions & 2 deletions rock/sandbox/gem_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from rock.sandbox.sandbox_actor import SandboxActor
from rock.sandbox.sandbox_manager import SandboxManager
from rock.utils.providers import RedisProvider
from rock.admin.core.ray_service import RayService


class GemManager(SandboxManager):
Expand All @@ -30,8 +29,9 @@ def __init__(
ray_namespace: str = env_vars.ROCK_RAY_NAMESPACE,
ray_service: RayService | None = None,
enable_runtime_auto_clear: bool = False,
operator=None,
):
super().__init__(rock_config, redis_provider, ray_namespace, ray_service, enable_runtime_auto_clear)
super().__init__(rock_config, redis_provider, ray_namespace, ray_service, enable_runtime_auto_clear, operator)

async def env_make(self, env_id: str) -> EnvMakeResponse:
config = DockerDeploymentConfig(image=env_vars.ROCK_ENVHUB_DEFAULT_DOCKER_IMAGE)
Expand Down
10 changes: 8 additions & 2 deletions rock/sandbox/operator/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from rock.actions.sandbox.sandbox_info import SandboxInfo
from rock.deployments.config import DeploymentConfig
from rock.utils.providers.redis_provider import RedisProvider


class AbstractOperator(ABC):
_redis_provider: RedisProvider | None = None

@abstractmethod
async def submit(self, config: DeploymentConfig) -> SandboxInfo:
async def submit(self, config: DeploymentConfig, user_info: dict = {}) -> SandboxInfo:
...

@abstractmethod
Expand All @@ -15,4 +18,7 @@ async def get_status(self, sandbox_id: str) -> SandboxInfo:

@abstractmethod
async def stop(self, sandbox_id: str) -> bool:
...
...

def set_redis_provider(self, redis_provider: RedisProvider):
self._redis_provider = redis_provider
60 changes: 60 additions & 0 deletions rock/sandbox/operator/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Operator factory for creating operator instances based on configuration."""

from dataclasses import dataclass, field
from typing import Any

from rock.admin.core.ray_service import RayService
from rock.config import RuntimeConfig
from rock.logger import init_logger
from rock.sandbox.operator.abstract import AbstractOperator
from rock.sandbox.operator.ray import RayOperator

logger = init_logger(__name__)


@dataclass
class OperatorContext:
"""Context object containing all dependencies needed for operator creation.

This design pattern solves the parameter explosion problem by encapsulating
all dependencies in a single context object. New operator types can add their
dependencies to this context without changing the factory method signature.
"""

runtime_config: RuntimeConfig
ray_service: RayService | None = None
# Future operator dependencies can be added here without breaking existing code
# kubernetes_client: Any | None = None
# docker_client: Any | None = None
extra_params: dict[str, Any] = field(default_factory=dict)


class OperatorFactory:
"""Factory class for creating operator instances.

Uses the Context Object pattern to avoid parameter explosion as new
operator types are added.
"""

@staticmethod
def create_operator(context: OperatorContext) -> AbstractOperator:
"""Create an operator instance based on the runtime configuration.

Args:
context: OperatorContext containing all necessary dependencies

Returns:
AbstractOperator: The created operator instance

Raises:
ValueError: If operator_type is not supported or required dependencies are missing
"""
operator_type = context.runtime_config.operator_type.lower()

if operator_type == "ray":
if context.ray_service is None:
raise ValueError("RayService is required for RayOperator")
logger.info("Creating RayOperator")
return RayOperator(ray_service=context.ray_service)
else:
raise ValueError(f"Unsupported operator type: {operator_type}. " f"Supported types: ray, kubernetes")
98 changes: 93 additions & 5 deletions rock/sandbox/operator/ray.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,102 @@
import json

import ray

from rock.actions.sandbox.response import State
from rock.actions.sandbox.sandbox_info import SandboxInfo
from rock.deployments.config import DeploymentConfig
from rock.admin.core.ray_service import RayService
from rock.admin.core.redis_key import alive_sandbox_key
from rock.deployments.config import DockerDeploymentConfig
from rock.deployments.docker import DockerDeployment
from rock.deployments.status import ServiceStatus
from rock.logger import init_logger
from rock.sandbox.operator.abstract import AbstractOperator
from rock.sandbox.sandbox_actor import SandboxActor
from rock.sdk.common.exceptions import BadRequestRockError
from rock.utils.format import parse_memory_size

logger = init_logger(__name__)


class RayOperator(AbstractOperator):
async def submit(self, config: DeploymentConfig) -> SandboxInfo:
return SandboxInfo(sandbox_id="test", host_name="test", host_ip="test")
def __init__(self, ray_service: RayService):
self._ray_service = ray_service

def _get_actor_name(self, sandbox_id: str) -> str:
return f"sandbox-{sandbox_id}"

async def create_actor(self, config: DockerDeploymentConfig):
actor_options = self._generate_actor_options(config)
deployment: DockerDeployment = config.get_deployment()
sandbox_actor = SandboxActor.options(**actor_options).remote(config, deployment)
return sandbox_actor

def _generate_actor_options(self, config: DockerDeploymentConfig) -> dict:
actor_name = self._get_actor_name(config.container_name)
actor_options = {"name": actor_name, "lifetime": "detached"}
try:
memory = parse_memory_size(config.memory)
actor_options["num_cpus"] = config.cpus
actor_options["memory"] = memory
return actor_options
except ValueError as e:
logger.warning(f"Invalid memory size: {config.memory}", exc_info=e)
raise BadRequestRockError(f"Invalid memory size: {config.memory}")

async def submit(self, config: DockerDeploymentConfig, user_info: dict = {}) -> SandboxInfo:
async with self._ray_service.get_ray_rwlock().read_lock():
sandbox_id = config.container_name
logger.info(f"[{sandbox_id}] start_async params:{json.dumps(config.model_dump(), indent=2)}")
sandbox_actor: SandboxActor = await self.create_actor(config)
sandbox_actor.start.remote()
user_id = user_info.get("user_id", "default")
experiment_id = user_info.get("experiment_id", "default")
namespace = user_info.get("namespace", "default")
rock_authorization = user_info.get("rock_authorization", "default")
sandbox_actor.set_user_id.remote(user_id)
sandbox_actor.set_experiment_id.remote(experiment_id)
sandbox_actor.set_namespace.remote(namespace)
sandbox_info: SandboxInfo = await self._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote())
sandbox_info["user_id"] = user_id
sandbox_info["experiment_id"] = experiment_id
sandbox_info["namespace"] = namespace
sandbox_info["state"] = State.PENDING
sandbox_info["rock_authorization"] = rock_authorization
logger.info(f"sandbox {sandbox_id} is submitted")
return sandbox_info

async def get_status(self, sandbox_id: str) -> SandboxInfo:
return SandboxInfo(sandbox_id="test", host_name="test", host_ip="test")
async with self._ray_service.get_ray_rwlock().read_lock():
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
sandbox_info: SandboxInfo = await self._ray_service.async_ray_get(actor.sandbox_info.remote())
remote_status: ServiceStatus = await self._ray_service.async_ray_get(actor.get_status.remote())
sandbox_info["phases"] = remote_status.phases
sandbox_info["port_mapping"] = remote_status.get_port_mapping()
alive = await self._ray_service.async_ray_get(actor.is_alive.remote())
if alive.is_alive:
sandbox_info["state"] = State.RUNNING
if not self._redis_provider:
return sandbox_info
redis_info = await self.get_sandbox_info_from_redis(sandbox_id)
if redis_info:
redis_info.update(sandbox_info)
redis_info["phases"] = {name: phase.to_dict() for name, phase in remote_status.phases.items()}
return redis_info
else:
return sandbox_info
# return sandbox_info

async def get_sandbox_info_from_redis(self, sandbox_id: str) -> SandboxInfo:
sandbox_status = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$")
if sandbox_status and len(sandbox_status) > 0:
sandbox_info = sandbox_status[0]
return sandbox_info
return None

async def stop(self, sandbox_id: str) -> bool:
return True
async with self._ray_service.get_ray_rwlock().read_lock():
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
await self._ray_service.async_ray_get(actor.stop.remote())
logger.info(f"run time stop over {sandbox_id}")
ray.kill(actor)
return True
Loading
Loading