1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ requires = ["poetry-core>=1.0.0", "pybind11", "setuptools", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.dependencies]
nv-one-logger-core = ">=2.1.0"
torch = ">=2.3.0"
packaging = "*"
python = ">=3.10"
13 changes: 13 additions & 0 deletions src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py
@@ -60,6 +60,7 @@
from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig

from ..shared_utils.health_check import GPUHealthCheck
from ..shared_utils.profiling import ProfilingEvent, record_profiling_event
from .data import WorkloadAction
from .ipc_connector import IpcConnector
from .launcher import FT_LAUNCHER_IPC_SOCKET, UnhealthyNodeException
@@ -1322,6 +1323,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
self._record(message=msg)
log.info(msg)

# Record rendezvous start event
rendezvous_start_event_id = record_profiling_event(
ProfilingEvent.RENDEZVOUS_STARTED,
node_id=self._this_node,
)

try:
self._stop_heartbeats()

@@ -1362,6 +1369,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
self._record(message=msg, rank=rank)
log.info(msg)

# Record rendezvous completion event
rendezvous_completion_event_id = record_profiling_event(
ProfilingEvent.RENDEZVOUS_COMPLETED,
node_id=self._this_node,
)

# Use RendezvousInfo if available (newer PyTorch versions >= 2.4.0)
# Fall back to tuple format if RendezvousInfo is not supported
if _RENDEZVOUS_INFO_AVAILABLE:
45 changes: 42 additions & 3 deletions src/nvidia_resiliency_ext/fault_tolerance/launcher.py
@@ -76,6 +76,7 @@
write_obj_to_ipc_stream,
)
from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event

# Deprecation warning for FT_LAUNCHER_LOGLEVEL
if os.getenv('FT_LAUNCHER_LOGLEVEL') is not None:
@@ -142,7 +143,7 @@ class LocalElasticAgent(SimpleElasticAgent):
python multiprocessing compatible. To pass multiprocessing data structures
to the workers you may create the data structure in the same multiprocessing
context as the specified ``start_method`` and pass it as a function argument.

Note: If your training script uses the nvrx logger, make sure to call
``setup_logger()`` at the beginning of your training function to ensure
the logger is properly set up in each subprocess.
@@ -183,12 +184,12 @@ def trainer(args) -> str:
# Ensure nvrx logger is set up in this subprocess
from nvidia_resiliency_ext.shared_utils.log_manager import setup_logger
setup_logger()

# Use the nvrx logger
import logging
logger = logging.getLogger(LogConfig.name)
logger.info("Training started")

return "do train"

def main():
@@ -255,6 +256,7 @@ def __init__(
self._ft_cfg = fault_tol_cfg
self._children_pgids: Set[int] = set()
self._restart_policy = restart_policy
self._node_id = self._get_fq_hostname()

DEFAULT_ROLE = "default" # FIXME

@@ -326,6 +328,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunResult:
self._exit_barrier()
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
# Record failure detection event
record_profiling_event(
ProfilingEvent.FAILURE_DETECTED,
node_id=self._rdzv_handler._this_node,
rank=self._worker_group.group_rank,
)

if self._remaining_restarts > 0:
logger.info(
"[%s] Worker group %s. "
@@ -351,6 +360,13 @@
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
# Record failure detection event
record_profiling_event(
ProfilingEvent.FAILURE_DETECTED,
node_id=self._rdzv_handler._this_node,
rank=self._worker_group.group_rank,
)

logger.info(
"[%s] Detected %s "
"new nodes from group_rank=%s; "
@@ -591,6 +607,13 @@ async def send_close_msg():

self._shutdown(timeout=self._workers_stop_timeout)

# Record worker termination event after shutdown is complete
record_profiling_event(
ProfilingEvent.WORKER_TERMINATED,
node_id=self._rdzv_handler._this_node,
rank=worker_group.group_rank,
)

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
@@ -600,6 +623,13 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
assert store is not None
restart_count = spec.max_restarts - self._remaining_restarts

# Record worker start event
record_profiling_event(
ProfilingEvent.WORKER_START_STARTED,
node_id=self._rdzv_handler._this_node,
rank=worker_group.group_rank,
)

use_agent_store = spec.rdzv_handler.use_agent_store

args: Dict[int, Tuple] = {}
@@ -671,8 +701,16 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:

self._children_pgids = {os.getpgid(p) for p in self._pcontext.pids().values()}

# Record worker start completion event
record_profiling_event(
ProfilingEvent.WORKER_START_COMPLETED,
node_id=self._rdzv_handler._this_node,
rank=worker_group.group_rank,
)

return self._pcontext.pids()


def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 30) -> None:
if self._worker_watchdog is not None:
self._worker_watchdog.stop()
@@ -1058,6 +1096,7 @@ def launch_agent(
)

logger.info(f"Agent .run() is OK. No failures in the result. {result=}")

return result.return_values
except UnhealthyNodeException as e:
# do not shutdown rendezvous when an unhealthy node is leaving
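
Taken together, the call sites above bracket each recovery cycle: FAILURE_DETECTED when the agent notices a failed/unhealthy worker group or waiting nodes, WORKER_TERMINATED after shutdown, and WORKER_START_STARTED/WORKER_START_COMPLETED around respawn. Below is a minimal sketch, not part of this PR, of how a consumer might pair those timestamps to estimate per-cycle recovery time; only the event names come from this change, and the input records are invented for illustration.

# Sketch only (not part of this PR): pair each FAILURE_DETECTED with the next
# WORKER_START_COMPLETED to estimate recovery time per cycle.
from typing import Dict, List

def recovery_seconds(events: List[Dict]) -> List[float]:
    """events: dicts with 'event' and 'timestamp' (epoch seconds), time-ordered."""
    durations, failure_ts = [], None
    for ev in events:
        if ev["event"] == "failure_detected" and failure_ts is None:
            failure_ts = ev["timestamp"]
        elif ev["event"] == "worker_start_completed" and failure_ts is not None:
            durations.append(ev["timestamp"] - failure_ts)
            failure_ts = None
    return durations

print(recovery_seconds([
    {"event": "failure_detected", "timestamp": 100.0},
    {"event": "worker_terminated", "timestamp": 101.5},
    {"event": "rendezvous_started", "timestamp": 102.0},
    {"event": "rendezvous_completed", "timestamp": 108.0},
    {"event": "worker_start_started", "timestamp": 108.2},
    {"event": "worker_start_completed", "timestamp": 112.7},
]))  # prints one duration of about 12.7 seconds
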
117 changes: 117 additions & 0 deletions src/nvidia_resiliency_ext/shared_utils/profiling.py
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file adds time profiling capabilities using nv one logger

import logging
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

from nv_one_logger.api.one_logger_provider import OneLoggerProvider

Review comment: @hexinw-nvidia Does this mean you are relying on the application for instantiation and configuration?

from nv_one_logger.core.attributes import Attributes
from nv_one_logger.core.event import Event

from ..shared_utils.log_manager import LogConfig


class ProfilingEvent(Enum):

Review thread:

Contributor: Any reason we are creating a custom profiler and not using something standard or OneLogger?

Contributor (author): We are using the PyTorch metrics system for the profiling. NV OneLogger seems to duplicate the functionality of the PyTorch metrics. There is no particular reason we went with PyTorch metrics; mainly, NVRx is tightly integrated with PyTorch.

Contributor (@apaithankar, Sep 18, 2025): Since Megatron-LM and NeMo both use NV OneLogger, and it seems to be better integrated with other services (log/event/metrics collection), would using NV OneLogger not make it easier to integrate when we have to create a dashboard in FACT? Not opposed to using TorchMetrics; just looking at the ecosystem for integration.

"""Enumeration of profiling events for fault tolerance metrics."""

FAILURE_DETECTED = "failure_detected"
WORKER_TERMINATED = "worker_terminated"
RENDEZVOUS_STARTED = "rendezvous_started"
RENDEZVOUS_COMPLETED = "rendezvous_completed"
WORKER_START_STARTED = "worker_start_started"
WORKER_START_COMPLETED = "worker_start_completed"


class FaultToleranceProfiler:
"""Profiler for measuring fault tolerance timing metrics using nv one logger."""

def __init__(self):
self._current_cycle = 0
# Initialize logger as a member to avoid module-level logger issues
self._logger = logging.getLogger(LogConfig.name)

def _timestamp_to_utc_datetime(self, timestamp: float) -> str:
"""Convert timestamp to UTC datetime string."""
utc_datetime = datetime.fromtimestamp(timestamp, tz=timezone.utc)
return utc_datetime.strftime("%Y-%m-%d %H:%M:%S.%f")[
:-3
] # Remove last 3 digits for milliseconds

def _publish_metrics(
self, event: ProfilingEvent, timestamp: float, node_id: Optional[str], rank: Optional[int]
) -> None:
"""Publish metrics using nv one logger."""
try:
# Check if nv one logger is available and enabled
if OneLoggerProvider.instance().one_logger_enabled:
# Create attributes for the event
attributes = Attributes()
attributes.add("event_type", event.value)
attributes.add("timestamp_ms", int(timestamp * 1000))
attributes.add("cycle", self._current_cycle)
if node_id:
attributes.add("node_id", node_id)
if rank is not None:
attributes.add("rank", rank)

# Create and record the event
event_obj = Event.create(f"ft.{event.value}", attributes)
OneLoggerProvider.instance().recorder.event(None, event_obj)

Review comment: @hexinw-nvidia Note that in the nv-one-logger API design, an event needs to belong to a span (e.g., the application span or a custom span). Recorder.start can create one, and the recorder will record the span's start and end timestamps.

except Exception as e:
# If nv one logger fails, just log a warning and continue
self._logger.warning(f"Failed to publish metrics to nv one logger: {e}")

def record_event(
self,
event: ProfilingEvent,
node_id: Optional[str] = None,
rank: Optional[int] = None,
) -> str:
"""Record a profiling event and return a unique event ID."""
timestamp = time.time()
event_id = f"{event.value}_{timestamp}_{node_id or 'unknown'}_{rank or 'unknown'}"

# Increment cycle count for failure detection events
if event == ProfilingEvent.FAILURE_DETECTED:
self._current_cycle += 1

# Publish metrics using nv one logger
self._publish_metrics(event, timestamp, node_id, rank)

# Format log message with cycle count and UTC time
utc_time = self._timestamp_to_utc_datetime(timestamp)
self._logger.info(
f" - Cycle: {self._current_cycle} Event: {event.value} Node: {node_id} Rank: {rank} "
f"Time: {utc_time} UTC"
)
return event_id


# Global profiler instance
_global_profiler = FaultToleranceProfiler()


def record_profiling_event(
event: ProfilingEvent,
node_id: Optional[str] = None,
rank: Optional[int] = None,
) -> str:
"""Convenience function to record a profiling event."""
return _global_profiler.record_event(event, node_id, rank)
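
For reference, a minimal usage sketch of the helper defined above (not part of this PR); ProfilingEvent and record_profiling_event come from this file, while the node_id/rank values are made up for illustration.

# Sketch only: record one rendezvous cycle from an agent process. Each call
# writes a " - Cycle: ... Event: ... Time: ... UTC" log line and, when the
# application has enabled OneLoggerProvider, also emits an "ft.<event>" Event.
from nvidia_resiliency_ext.shared_utils.profiling import (
    ProfilingEvent,
    record_profiling_event,
)

start_id = record_profiling_event(
    ProfilingEvent.RENDEZVOUS_STARTED, node_id="node-0", rank=0
)
done_id = record_profiling_event(
    ProfilingEvent.RENDEZVOUS_COMPLETED, node_id="node-0", rank=0
)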