src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py: 19 additions & 0 deletions
@@ -60,6 +60,7 @@
from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig

from ..shared_utils.health_check import GPUHealthCheck
from ..shared_utils.profiling import ProfilingEvent, record_profiling_event
from .data import WorkloadAction
from .ipc_connector import IpcConnector
from .launcher import FT_LAUNCHER_IPC_SOCKET, UnhealthyNodeException
@@ -1322,6 +1323,13 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
self._record(message=msg)
log.info(msg)

# Record rendezvous start event
rendezvous_start_event_id = record_profiling_event(
ProfilingEvent.RENDEZVOUS_STARTED,
node_id=self._this_node.addr,
metadata={"run_id": self._settings.run_id, "round": self._state_holder.state.round},
)

try:
self._stop_heartbeats()

@@ -1362,6 +1370,17 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
self._record(message=msg, rank=rank)
log.info(msg)

# Record rendezvous completion event
rendezvous_completion_event_id = record_profiling_event(
ProfilingEvent.RENDEZVOUS_COMPLETED,
node_id=self._this_node.addr,
metadata={
"run_id": self._settings.run_id,
"round": self._state_holder.state.round,
"world_size": world_size,
},
)

# Use RendezvousInfo if available (newer PyTorch versions >= 2.4.0)
# Fall back to tuple format if RendezvousInfo is not supported
if _RENDEZVOUS_INFO_AVAILABLE:
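
Note: the shared_utils/profiling module itself is not part of these hunks, so the shape of record_profiling_event and ProfilingEvent can only be inferred from the call sites in this PR. A minimal sketch, assuming an in-process event list, uuid event ids, and monotonic timestamps (the _EventRecord dataclass and the _EVENTS storage are illustrative names, not the module's actual implementation):

    # Hypothetical sketch of the profiling helpers used in this PR; field names
    # and storage are assumptions made for illustration only.
    import time
    import uuid
    from dataclasses import dataclass, field
    from enum import Enum, auto
    from typing import Any, Dict, List, Optional


    class ProfilingEvent(Enum):
        # Members referenced by the call sites in this PR.
        RENDEZVOUS_STARTED = auto()
        RENDEZVOUS_COMPLETED = auto()
        FAILURE_DETECTED = auto()
        WORKER_TERMINATED = auto()
        WORKER_START_STARTED = auto()
        WORKER_START_COMPLETED = auto()


    @dataclass
    class _EventRecord:
        event_id: str
        event: ProfilingEvent
        node_id: str
        rank: Optional[int]
        timestamp: float
        metadata: Dict[str, Any] = field(default_factory=dict)


    _EVENTS: List[_EventRecord] = []


    def record_profiling_event(
        event: ProfilingEvent,
        node_id: str,
        rank: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Record one profiling event and return its id, mirroring the call sites above."""
        record = _EventRecord(
            event_id=str(uuid.uuid4()),
            event=event,
            node_id=node_id,
            rank=rank,
            timestamp=time.monotonic(),
            metadata=metadata or {},
        )
        _EVENTS.append(record)
        return record.event_id
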
src/nvidia_resiliency_ext/fault_tolerance/launcher.py: 54 additions & 3 deletions
@@ -76,6 +76,11 @@
write_obj_to_ipc_stream,
)
from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
from nvidia_resiliency_ext.shared_utils.profiling import (
ProfilingEvent,
log_current_cycle_summary,
record_profiling_event,
)

# Deprecation warning for FT_LAUNCHER_LOGLEVEL
if os.getenv('FT_LAUNCHER_LOGLEVEL') is not None:
@@ -138,7 +143,7 @@ class LocalElasticAgent(SimpleElasticAgent):
python multiprocessing compatible. To pass multiprocessing data structures
to the workers you may create the data structure in the same multiprocessing
context as the specified ``start_method`` and pass it as a function argument.

Note: If your training script uses the nvrx logger, make sure to call
``setup_logger()`` at the beginning of your training function to ensure
the logger is properly set up in each subprocess.
@@ -179,12 +184,12 @@ def trainer(args) -> str:
# Ensure nvrx logger is set up in this subprocess
from nvidia_resiliency_ext.shared_utils.log_manager import setup_logger
setup_logger()

# Use the nvrx logger
import logging
logger = logging.getLogger(LogConfig.name)
logger.info("Training started")

return "do train"

def main():
@@ -251,6 +256,7 @@ def __init__(
self._ft_cfg = fault_tol_cfg
self._children_pgids: Set[int] = set()
self._restart_policy = restart_policy
self._node_id = self._get_fq_hostname()

DEFAULT_ROLE = "default" # FIXME

@@ -322,6 +328,14 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
self._exit_barrier()
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
# Record failure detection event
record_profiling_event(
ProfilingEvent.FAILURE_DETECTED,
node_id=self._node_id,
rank=self._worker_group.group_rank,
metadata={"state": state.name, "role": role}
)

if self._remaining_restarts > 0:
logger.info(
"[%s] Worker group %s. "
@@ -347,6 +361,14 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
# Record failure detection event
record_profiling_event(
ProfilingEvent.FAILURE_DETECTED,
node_id=self._node_id,
rank=self._worker_group.group_rank,
metadata={"state": state.name, "role": role}
)

logger.info(
"[%s] Detected %s "
"new nodes from group_rank=%s; "
@@ -587,6 +609,14 @@ async def send_close_msg():

self._shutdown(timeout=self._workers_stop_timeout)

# Record worker termination event after shutdown is complete
record_profiling_event(
ProfilingEvent.WORKER_TERMINATED,
node_id=self._node_id,
rank=worker_group.group_rank,
metadata={"group_rank": worker_group.group_rank, "world_size": worker_group.group_world_size}
)

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
@@ -596,6 +626,14 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
assert store is not None
restart_count = spec.max_restarts - self._remaining_restarts

# Record worker start initiation event
start_start_event_id = record_profiling_event(
ProfilingEvent.WORKER_START_STARTED,
node_id=self._node_id,
rank=worker_group.group_rank,
metadata={"group_rank": worker_group.group_rank, "world_size": worker_group.group_world_size}
)

use_agent_store = spec.rdzv_handler.use_agent_store

args: Dict[int, Tuple] = {}
@@ -667,8 +705,20 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:

self._children_pgids = {os.getpgid(p) for p in self._pcontext.pids().values()}

# Record worker start completion event
start_completion_event_id = record_profiling_event(
ProfilingEvent.WORKER_START_COMPLETED,
node_id=self._node_id,
rank=worker_group.group_rank,
metadata={"group_rank": worker_group.group_rank, "world_size": worker_group.group_world_size}
)

# Log profiling summary for this restart cycle
log_current_cycle_summary()

return self._pcontext.pids()


def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 30) -> None:
if self._worker_watchdog is not None:
self._worker_watchdog.stop()
@@ -1054,6 +1104,7 @@ def launch_agent(
)

logger.info(f"Agent .run() is OK. No failures in the result. {result=}")

return result.return_values
except UnhealthyNodeException as e:
# do not shutdown rendezvous when an unhealthy node is leaving
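
Taken together, _start_workers now brackets worker startup with WORKER_START_STARTED and WORKER_START_COMPLETED and then calls log_current_cycle_summary() for the restart cycle. A hedged standalone usage example of the new helpers, based only on the call pattern visible in the diff (the node_id value, the metadata values, and what the summary actually logs are assumptions):

    from nvidia_resiliency_ext.shared_utils.profiling import (
        ProfilingEvent,
        log_current_cycle_summary,
        record_profiling_event,
    )

    # Bracket a phase the same way _start_workers does above; "node-0" stands in
    # for the fully qualified hostname the launcher uses as its node id.
    start_event_id = record_profiling_event(
        ProfilingEvent.WORKER_START_STARTED,
        node_id="node-0",
        rank=0,
        metadata={"group_rank": 0, "world_size": 2},
    )

    # ... start worker processes here ...

    completion_event_id = record_profiling_event(
        ProfilingEvent.WORKER_START_COMPLETED,
        node_id="node-0",
        rank=0,
        metadata={"group_rank": 0, "world_size": 2},
    )

    # Emit the per-restart-cycle timing summary, as _start_workers does after launch.
    log_current_cycle_summary()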