diff --git a/pyproject.toml b/pyproject.toml
index bfb86b02..eda553a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ requires = ["poetry-core>=1.0.0", "pybind11", "setuptools", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry.dependencies]
+nv-one-logger-core = ">=2.1.0"
 torch = ">=2.3.0"
 packaging = "*"
 python = ">=3.10"
diff --git a/src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py b/src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py
index e0579422..c2ee1805 100644
--- a/src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py
+++ b/src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py
@@ -60,6 +60,7 @@ from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig
 
 from ..shared_utils.health_check import GPUHealthCheck
+from ..shared_utils.profiling import ProfilingEvent, record_profiling_event
 from .data import WorkloadAction
 from .ipc_connector import IpcConnector
 from .launcher import FT_LAUNCHER_IPC_SOCKET, UnhealthyNodeException
@@ -1322,6 +1323,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
         self._record(message=msg)
         log.info(msg)
 
+        # Record rendezvous start event
+        rendezvous_start_event_id = record_profiling_event(
+            ProfilingEvent.RENDEZVOUS_STARTED,
+            node_id=self._this_node,
+        )
+
         try:
             self._stop_heartbeats()
 
@@ -1362,6 +1369,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
         self._record(message=msg, rank=rank)
         log.info(msg)
 
+        # Record rendezvous completion event
+        rendezvous_completion_event_id = record_profiling_event(
+            ProfilingEvent.RENDEZVOUS_COMPLETED,
+            node_id=self._this_node,
+        )
+
         # Use RendezvousInfo if available (newer PyTorch versions >= 2.4.0)
         # Fall back to tuple format if RendezvousInfo is not supported
         if _RENDEZVOUS_INFO_AVAILABLE:
diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py
index 61d4f340..ba81f836 100644
--- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py
+++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py
@@ -76,6 +76,7 @@
     write_obj_to_ipc_stream,
 )
 from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
+from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event
 
 # Deprecation warning for FT_LAUNCHER_LOGLEVEL
 if os.getenv('FT_LAUNCHER_LOGLEVEL') is not None:
@@ -142,7 +143,7 @@ class LocalElasticAgent(SimpleElasticAgent):
     python multiprocessing compatible. To pass multiprocessing data structures to
     the workers you may create the data structure in the same multiprocessing
     context as the specified ``start_method`` and pass it as a function argument.
-    
+
     Note: If your training script uses the nvrx logger, make sure to call
     ``setup_logger()`` at the beginning of your training function to ensure
     the logger is properly set up in each subprocess.
@@ -183,12 +184,12 @@ def trainer(args) -> str:
             # Ensure nvrx logger is set up in this subprocess
             from nvidia_resiliency_ext.shared_utils.log_manager import setup_logger
             setup_logger()
-            
+
             # Use the nvrx logger
             import logging
             logger = logging.getLogger(LogConfig.name)
             logger.info("Training started")
-            
+
             return "do train"
 
         def main():
@@ -255,6 +256,7 @@ def __init__(
         self._ft_cfg = fault_tol_cfg
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
+        self._node_id = self._get_fq_hostname()
 
 
 DEFAULT_ROLE = "default"  # FIXME
@@ -326,6 +328,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                 self._exit_barrier()
                 return run_result
             elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
+                # Record failure detection event
+                record_profiling_event(
+                    ProfilingEvent.FAILURE_DETECTED,
+                    node_id=self._rdzv_handler._this_node,
+                    rank=self._worker_group.group_rank,
+                )
+
                 if self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
@@ -351,6 +360,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                 num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                 group_rank = self._worker_group.group_rank
                 if num_nodes_waiting > 0:
+                    # Record failure detection event
+                    record_profiling_event(
+                        ProfilingEvent.FAILURE_DETECTED,
+                        node_id=self._rdzv_handler._this_node,
+                        rank=self._worker_group.group_rank,
+                    )
+
                     logger.info(
                         "[%s] Detected %s "
                         "new nodes from group_rank=%s; "
@@ -591,6 +607,13 @@ async def send_close_msg():
 
         self._shutdown(timeout=self._workers_stop_timeout)
 
+        # Record worker termination event after shutdown is complete
+        record_profiling_event(
+            ProfilingEvent.WORKER_TERMINATED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
     # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
     #  `torch.distributed.elastic.metrics.prof`.
     @prof
@@ -600,6 +623,13 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
         assert store is not None
         restart_count = spec.max_restarts - self._remaining_restarts
 
+        # Record worker start event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_STARTED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
         use_agent_store = spec.rdzv_handler.use_agent_store
 
         args: Dict[int, Tuple] = {}
@@ -671,8 +701,16 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
 
         self._children_pgids = {os.getpgid(p) for p in self._pcontext.pids().values()}
 
+        # Record worker start completion event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_COMPLETED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
         return self._pcontext.pids()
 
+
     def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 30) -> None:
         if self._worker_watchdog is not None:
             self._worker_watchdog.stop()
@@ -1058,6 +1096,7 @@ def launch_agent(
         )
         logger.info(f"Agent .run() is OK. No failures in the result. {result=}")
 
+        return result.return_values
     except UnhealthyNodeException as e:
         # do not shutdown rendezvous when an unhealthy node is leaving
diff --git a/src/nvidia_resiliency_ext/shared_utils/profiling.py b/src/nvidia_resiliency_ext/shared_utils/profiling.py
new file mode 100644
index 00000000..b3688421
--- /dev/null
+++ b/src/nvidia_resiliency_ext/shared_utils/profiling.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file adds time profiling capabilities using nv one logger
+
+import logging
+import time
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Optional
+
+from nv_one_logger.api.one_logger_provider import OneLoggerProvider
+from nv_one_logger.core.attributes import Attributes
+from nv_one_logger.core.event import Event
+
+from ..shared_utils.log_manager import LogConfig
+
+
+class ProfilingEvent(Enum):
+    """Enumeration of profiling events for fault tolerance metrics."""
+
+    FAILURE_DETECTED = "failure_detected"
+    WORKER_TERMINATED = "worker_terminated"
+    RENDEZVOUS_STARTED = "rendezvous_started"
+    RENDEZVOUS_COMPLETED = "rendezvous_completed"
+    WORKER_START_STARTED = "worker_start_started"
+    WORKER_START_COMPLETED = "worker_start_completed"
+
+
+class FaultToleranceProfiler:
+    """Profiler for measuring fault tolerance timing metrics using nv one logger."""
+
+    def __init__(self):
+        self._current_cycle = 0
+        # Initialize logger as a member to avoid module-level logger issues
+        self._logger = logging.getLogger(LogConfig.name)
+
+    def _timestamp_to_utc_datetime(self, timestamp: float) -> str:
+        """Convert timestamp to UTC datetime string."""
+        utc_datetime = datetime.fromtimestamp(timestamp, tz=timezone.utc)
+        return utc_datetime.strftime("%Y-%m-%d %H:%M:%S.%f")[
+            :-3
+        ]  # Truncate microseconds to milliseconds
+
+    def _publish_metrics(
+        self, event: ProfilingEvent, timestamp: float, node_id: Optional[str], rank: Optional[int]
+    ) -> None:
+        """Publish metrics using nv one logger."""
+        try:
+            # Check if nv one logger is available and enabled
+            if OneLoggerProvider.instance().one_logger_enabled:
+                # Create attributes for the event
+                attributes = Attributes()
+                attributes.add("event_type", event.value)
+                attributes.add("timestamp_ms", int(timestamp * 1000))
+                attributes.add("cycle", self._current_cycle)
+                if node_id:
+                    attributes.add("node_id", node_id)
+                if rank is not None:
+                    attributes.add("rank", rank)
+
+                # Create and record the event
+                event_obj = Event.create(f"ft.{event.value}", attributes)
+                OneLoggerProvider.instance().recorder.event(None, event_obj)
+        except Exception as e:
+            # If nv one logger fails, just log a warning and continue
+            self._logger.warning(f"Failed to publish metrics to nv one logger: {e}")
+
+    def record_event(
+        self,
+        event: ProfilingEvent,
+        node_id: Optional[str] = None,
+        rank: Optional[int] = None,
+    ) -> str:
+        """Record a profiling event and return a unique event ID."""
+        timestamp = time.time()
+        event_id = f"{event.value}_{timestamp}_{node_id or 'unknown'}_{rank if rank is not None else 'unknown'}"
+
+        # Increment cycle count for failure detection events
+        if event == ProfilingEvent.FAILURE_DETECTED:
+            self._current_cycle += 1
+
+        # Publish metrics using nv one logger
+        self._publish_metrics(event, timestamp, node_id, rank)
+
+        # Format log message with cycle count and UTC time
+        utc_time = self._timestamp_to_utc_datetime(timestamp)
+        self._logger.info(
f" - Cycle: {self._current_cycle} Event: {event.value} Node: {node_id} Rank: {rank} " + f"Time: {utc_time} UTC" + ) + return event_id + + +# Global profiler instance +_global_profiler = FaultToleranceProfiler() + + +def record_profiling_event( + event: ProfilingEvent, + node_id: Optional[str] = None, + rank: Optional[int] = None, +) -> str: + """Convenience function to record a profiling event.""" + return _global_profiler.record_event(event, node_id, rank)