Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions docs/source/guides/execution.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,23 @@ def your_skypilot_executor(nodes: int, devices: int, container_image: str):
return SkypilotExecutor(
gpus="RTX5880-ADA-GENERATION",
gpus_per_node=devices,
nodes = nodes
env_vars=common_envs()
num_nodes = nodes,
env_vars=common_envs(),
container_image=container_image,
cloud="kubernetes",
infra="k8s/mycontext",
# Optional: reuse an existing SkyPilot cluster
cluster_name="tester",
volumes={
"nemo-workspace": "nemo-workspace"
},
volume_mounts=[
{
"path": "/data",
"volume_name": "nemo-workspace",
"size": "50Gi",
"type": "k8s-pvc"
}
],
setup="""
conda deactivate
nvidia-smi
Expand Down
91 changes: 90 additions & 1 deletion nemo_run/core/execution/skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

from invoke.context import Context

Expand All @@ -36,6 +36,8 @@
import sky.task as skyt
from sky import backends
from sky.utils import status_lib
from sky.volumes import volume as volume_lib
from sky import models

_SKYPILOT_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -94,6 +96,8 @@ class SkypilotExecutor(Executor):
memory: Optional[Union[int | float, list[int | float]]] = None
instance_type: Optional[Union[str, list[str]]] = None
num_nodes: int = 1
volumes: Optional[Dict[str, str]] = None
volume_mounts: Optional[List[Any]] = None
use_spot: Optional[Union[bool, list[bool]]] = None
disk_size: Optional[Union[int, list[int]]] = None
disk_tier: Optional[Union[str, list[str]]] = None
Expand Down Expand Up @@ -341,6 +345,64 @@ def macro_values(self) -> Optional[ExecutorMacros]:
het_group_host_var=self.HET_GROUP_HOST_VAR,
)

def _setup_launcher(self):
# Auto-enable torchrun for distributed training scenarios:
# 1. Multi-node training (num_nodes > 1)
# 2. Single-node multi-GPU training (gpus_per_node > 1)
if self.launcher is None and (
self.num_nodes > 1 or (self.gpus_per_node and self.gpus_per_node > 1)
):
self.launcher = "torchrun"

super()._setup_launcher()

def supports_launcher_transform(self) -> bool:
return True

def _parse_infra_for_volume_config(self) -> dict:
"""Parse infra string and return volume config parameters."""
config = {}

if self.infra is not None:
# Parse infra string to extract cloud, region, zone components
# Format: cloud, cloud/region, cloud/region/zone, k8s/context
infra_parts = self.infra.split("/")
cloud = infra_parts[0] if infra_parts else None

if cloud:
# Special handling for Kubernetes
if cloud == "k8s":
# VolumeConfig region and zone required even though they are marked as optional
# validation fails otherwise
config["cloud"] = "kubernetes"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this value needs to be kubernetes - based on the providers supported list in skypilot

config["region"] = "kubernetes"
config["zone"] = "kubernetes"
else:
# Handle regular cloud providers
config["cloud"] = cloud

# Handle region for non-k8s clouds
if len(infra_parts) >= 2:
region = infra_parts[1]
if region and region != "*": # Skip wildcards
config["region"] = region

# Handle zone for non-k8s clouds
if len(infra_parts) >= 3:
zone = infra_parts[2]
if zone and zone != "*": # Skip wildcards
config["zone"] = zone
else:
# Fall back to individual cloud, region, zone parameters
if self.cloud:
config["cloud"] = self.cloud
if self.region:
config["region"] = self.region
if self.zone:
config["zone"] = self.zone

return config

def to_task(
self,
name: str,
Expand All @@ -364,16 +426,43 @@ def to_task(

{" ".join(cmd)}
"""

task = Task(
name=name,
setup=self.setup if self.setup else "",
run=run_cmd,
envs=self.env_vars,
num_nodes=self.num_nodes,
volumes=self.volumes,
)

file_mounts = self.file_mounts or {}
file_mounts["/nemo_run"] = self.job_dir
task.set_file_mounts(file_mounts)
task.set_volumes(self.volumes)

volume_mounts = []
if self.volume_mounts:
for volume_mount in self.volume_mounts:
# Configure volume based on infra if specified, otherwise use cloud/region/zone
volume_config_kwargs = {
"name": volume_mount["volume_name"],
"type": volume_mount["type"],
"name_on_cloud": volume_mount["volume_name"],
"size": volume_mount["size"],
}

# Add parsed infra configuration
volume_config_kwargs.update(self._parse_infra_for_volume_config())

volume_mounts.append(
volume_lib.VolumeMount(
path=volume_mount["path"],
volume_name=volume_mount["volume_name"],
volume_config=models.VolumeConfig(**volume_config_kwargs),
)
)
task.volume_mounts = volume_mounts
task.set_resources(self.to_resources())

if env_vars:
Expand Down
230 changes: 230 additions & 0 deletions test/core/execution/test_skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,233 @@ def test_to_task(self, mock_task, mock_skypilot_imports, executor):

# Verify the returned task is our mock
assert result == mock_task_instance

def test_parse_infra_for_volume_config(self, mock_skypilot_imports):
"""Test the _parse_infra_for_volume_config helper method."""

# Test k8s infra
executor_k8s = SkypilotExecutor(infra="k8s/my-context")
config = executor_k8s._parse_infra_for_volume_config()
assert config["cloud"] == "kubernetes"
assert config["region"] == "kubernetes"
assert config["zone"] == "kubernetes"

# Test AWS infra with region and zone
executor_aws = SkypilotExecutor(infra="aws/us-east-1/us-east-1a")
config = executor_aws._parse_infra_for_volume_config()
assert config["cloud"] == "aws"
assert config["region"] == "us-east-1"
assert config["zone"] == "us-east-1a"

# Test fallback to individual parameters
executor_fallback = SkypilotExecutor(
cloud="gcp", region="us-central1", zone="us-central1-a"
)
config = executor_fallback._parse_infra_for_volume_config()
assert config["cloud"] == "gcp"
assert config["region"] == "us-central1"
assert config["zone"] == "us-central1-a"

def test_volume_mounts_initialization(self, mock_skypilot_imports):
"""Test that volume_mounts are properly stored during initialization."""
volume_mounts = [
{"path": "/data", "volume_name": "nemo-workspace", "size": "50Gi", "type": "k8s-pvc"}
]

executor = SkypilotExecutor(
container_image="nvcr.io/nvidia/nemo:latest",
cloud="kubernetes",
cluster_name="test-cluster",
volume_mounts=volume_mounts,
)

# Verify volume_mounts are stored correctly
assert executor.volume_mounts == volume_mounts
assert len(executor.volume_mounts) == 1
assert executor.volume_mounts[0]["path"] == "/data"
assert executor.volume_mounts[0]["volume_name"] == "nemo-workspace"

def test_volume_mounts_none(self, mock_skypilot_imports):
"""Test that volume_mounts can be None."""
executor = SkypilotExecutor(
container_image="nvcr.io/nvidia/nemo:latest",
cloud="kubernetes",
cluster_name="test-cluster",
volume_mounts=None,
)

# Verify volume_mounts is None
assert executor.volume_mounts is None

    # NOTE: patch decorators apply bottom-up, so the mock arguments arrive in
    # reverse order: VolumeConfig, VolumeMount, Task (then the pytest fixture).
    @patch("sky.task.Task")
    @patch("sky.volumes.volume.VolumeMount")
    @patch("sky.models.VolumeConfig")
    def test_volume_mounts_to_task_processing(
        self, mock_volume_config, mock_volume_mount, mock_task, mock_skypilot_imports
    ):
        """Test that volume_mounts are processed in to_task method."""
        mock_task_instance = MagicMock()
        mock_task.return_value = mock_task_instance

        volume_mounts = [
            {"path": "/data", "volume_name": "nemo-workspace", "size": "50Gi", "type": "k8s-pvc"}
        ]

        # region/zone mirror cloud="kubernetes" — presumably because
        # VolumeConfig validation requires them even though optional; confirm
        # against the executor's _parse_infra_for_volume_config.
        executor = SkypilotExecutor(
            container_image="nvcr.io/nvidia/nemo:latest",
            cloud="kubernetes",
            region="kubernetes",
            zone="kubernetes",
            cluster_name="test-cluster",
            volume_mounts=volume_mounts,
        )

        # Stub out to_resources so no real cloud resource resolution happens.
        with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
            mock_to_resources.return_value = MagicMock()

            with tempfile.TemporaryDirectory() as tmp_dir:
                executor.job_dir = tmp_dir
                executor.to_task("test_task", ["python", "script.py"])

            # Verify volume processing was called (covers the processing logic)
            mock_volume_mount.assert_called_once()
            mock_volume_config.assert_called_once()

    # NOTE: patch decorators apply bottom-up, so the mock arguments arrive in
    # reverse order: VolumeConfig, VolumeMount, Task (then the pytest fixture).
    @patch("sky.task.Task")
    @patch("sky.volumes.volume.VolumeMount")
    @patch("sky.models.VolumeConfig")
    def test_volume_mounts_with_infra(
        self, mock_volume_config, mock_volume_mount, mock_task, mock_skypilot_imports
    ):
        """Test volume_mounts processing when using infra instead of cloud/region/zone."""
        mock_task_instance = MagicMock()
        mock_task.return_value = mock_task_instance

        volume_mounts = [
            {"path": "/data", "volume_name": "nemo-workspace", "size": "50Gi", "type": "k8s-pvc"}
        ]

        # infra uses the "k8s/<context>" form; the context itself is opaque to
        # the volume-config parsing (only the "k8s" prefix matters).
        executor = SkypilotExecutor(
            container_image="nvcr.io/nvidia/nemo:latest",
            infra="k8s/kubernetes-admin@kubernetes",
            cluster_name="test-cluster",
            volume_mounts=volume_mounts,
        )

        # Stub out to_resources so no real cloud resource resolution happens.
        with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
            mock_to_resources.return_value = MagicMock()

            with tempfile.TemporaryDirectory() as tmp_dir:
                executor.job_dir = tmp_dir
                executor.to_task("test_task", ["python", "script.py"])

            # Verify volume processing was called (covers the infra parsing logic)
            mock_volume_mount.assert_called_once()
            mock_volume_config.assert_called_once()

    # NOTE: patch decorators apply bottom-up, so the mock arguments arrive in
    # reverse order: VolumeConfig, VolumeMount, Task (then the pytest fixture).
    @patch("sky.task.Task")
    @patch("sky.volumes.volume.VolumeMount")
    @patch("sky.models.VolumeConfig")
    def test_volume_mounts_with_aws_infra(
        self, mock_volume_config, mock_volume_mount, mock_task, mock_skypilot_imports
    ):
        """Test volume_mounts with AWS infra format (cloud/region/zone)."""
        mock_task_instance = MagicMock()
        mock_task.return_value = mock_task_instance

        volume_mounts = [
            {"path": "/data", "volume_name": "test-vol", "size": "10Gi", "type": "gp2"}
        ]

        # Full three-part infra string: cloud/region/zone.
        executor = SkypilotExecutor(
            container_image="nvcr.io/nvidia/nemo:latest",
            infra="aws/us-east-1/us-east-1a",
            volume_mounts=volume_mounts,
        )

        # Stub out to_resources so no real cloud resource resolution happens.
        with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
            mock_to_resources.return_value = MagicMock()

            with tempfile.TemporaryDirectory() as tmp_dir:
                executor.job_dir = tmp_dir
                executor.to_task("test_task", ["python", "script.py"])

            # Verify AWS infra parsing was covered
            mock_volume_mount.assert_called_once()
            mock_volume_config.assert_called_once()

    # NOTE: patch decorators apply bottom-up, so the mock arguments arrive in
    # reverse order: VolumeConfig, VolumeMount, Task (then the pytest fixture).
    @patch("sky.task.Task")
    @patch("sky.volumes.volume.VolumeMount")
    @patch("sky.models.VolumeConfig")
    def test_volume_mounts_with_gcp_infra_wildcard(
        self, mock_volume_config, mock_volume_mount, mock_task, mock_skypilot_imports
    ):
        """Test volume_mounts with GCP infra including wildcard zone."""
        mock_task_instance = MagicMock()
        mock_task.return_value = mock_task_instance

        volume_mounts = [
            {"path": "/data", "volume_name": "test-vol", "size": "10Gi", "type": "pd-ssd"}
        ]

        executor = SkypilotExecutor(
            container_image="nvcr.io/nvidia/nemo:latest",
            infra="gcp/us-central1/*",  # Wildcard zone should be skipped
            volume_mounts=volume_mounts,
        )

        # Stub out to_resources so no real cloud resource resolution happens.
        with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
            mock_to_resources.return_value = MagicMock()

            with tempfile.TemporaryDirectory() as tmp_dir:
                executor.job_dir = tmp_dir
                executor.to_task("test_task", ["python", "script.py"])

            # Verify wildcard zone handling was covered
            mock_volume_mount.assert_called_once()
            mock_volume_config.assert_called_once()

    # NOTE: patch decorators apply bottom-up, so the mock arguments arrive in
    # reverse order: VolumeConfig, VolumeMount, Task (then the pytest fixture).
    @patch("sky.task.Task")
    @patch("sky.volumes.volume.VolumeMount")
    @patch("sky.models.VolumeConfig")
    def test_volume_mounts_fallback_individual_params(
        self, mock_volume_config, mock_volume_mount, mock_task, mock_skypilot_imports
    ):
        """Test volume_mounts fallback to individual cloud/region/zone params when infra is None."""
        mock_task_instance = MagicMock()
        mock_task.return_value = mock_task_instance

        volume_mounts = [
            {"path": "/data", "volume_name": "test-vol", "size": "10Gi", "type": "gp2"}
        ]

        # Test with individual parameters (no infra) - covers fallback path
        executor = SkypilotExecutor(
            container_image="nvcr.io/nvidia/nemo:latest",
            cloud="aws",
            region="us-west-2",
            zone="us-west-2a",
            volume_mounts=volume_mounts,
        )

        # Stub out to_resources so no real cloud resource resolution happens.
        with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
            mock_to_resources.return_value = MagicMock()

            with tempfile.TemporaryDirectory() as tmp_dir:
                executor.job_dir = tmp_dir
                executor.to_task("test_task", ["python", "script.py"])

            # Verify fallback to individual params was covered
            mock_volume_mount.assert_called_once()
            mock_volume_config.assert_called_once()

def test_supports_launcher_transform(self, mock_skypilot_imports):
"""Test that supports_launcher_transform returns True."""
executor = SkypilotExecutor(
container_image="nvcr.io/nvidia/nemo:latest",
cloud="kubernetes",
cluster_name="test-cluster",
)

# Test the method returns True
assert executor.supports_launcher_transform() is True
Loading