Skip to content

Commit 4dcb8fe

Browse files
Merge branch 'main' into profiling
2 parents: 04168a8 + e0fa23e; commit: 4dcb8fe

File tree

2 files changed

+28 -14 lines changed

2 files changed

+28 -14 lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@ packaging = "*"
3737
python = ">=3.10"
3838
psutil = ">=6.0.0"
3939
pyyaml = "*"
40-
pynvml = ">=12.0.0"
4140
nvidia-ml-py = ">=12.570.86"
4241
defusedxml = "*"
4342

src/nvidia_resiliency_ext/inprocess/rank_assignment.py

Lines changed: 28 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@
2626
import warnings
2727
from typing import Callable, Optional, Union
2828

29+
from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig
30+
2931
from . import exception, utils
3032
from .state import Mode, State
3133
from .store import StoreMixin
@@ -177,7 +179,8 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
177179
active_rank = None
178180
# Log deactivation if transitioning from ACTIVE to INACTIVE
179181
if state.mode == Mode.ACTIVE:
180-
log = logging.getLogger(__name__)
182+
log = logging.getLogger(LogConfig.name)
183+
181184
log.info(
182185
f"[In-process] Rank deactivated (rank={state.rank}) due to max active world size limit ({active_world_size})"
183186
)
@@ -224,7 +227,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
224227
active_rank = None
225228
# Log deactivation if transitioning from ACTIVE to INACTIVE
226229
if state.mode == Mode.ACTIVE:
227-
log = logging.getLogger(__name__)
230+
log = logging.getLogger(LogConfig.name)
228231
log.info(
229232
f"[In-process] Rank deactivated (rank={state.rank}) due to divisibility requirement (active_world_size={active_world_size}, divisor={divisor})"
230233
)
@@ -349,7 +352,7 @@ def __repr__(self):
349352
return f'{type(self).__name__}({self.name=})'
350353

351354

352-
def bounded_activate(node, counter, path=None):
355+
def bounded_activate(node, counter, path=None, current_state=None):
353356
if path is None:
354357
path = []
355358

@@ -361,17 +364,29 @@ def bounded_activate(node, counter, path=None):
361364
for ascendant in path
362365
)
363366
):
367+
# Log activation if this is the current rank
368+
if current_state and current_state.initial_rank == node.state.initial_rank:
369+
log = logging.getLogger(LogConfig.name)
370+
log.info(
371+
f"[In-process] Rank activated (initial_rank={node.state.initial_rank}, active_rank={counter}) in topology tree"
372+
)
364373
node.activate(counter)
365374
counter += 1
366375
for ascendant in path:
367376
ascendant.active_count += 1
368377
else:
378+
# Log deactivation if this is the current rank
379+
if current_state and current_state.initial_rank == node.state.initial_rank:
380+
log = logging.getLogger(LogConfig.name)
381+
log.info(
382+
f"[In-process] Rank deactivated (initial_rank={node.state.initial_rank}) due to max_ranks constraint in topology layer"
383+
)
369384
node.deactivate()
370385

371386
path.append(node)
372387

373388
for child in node.children.values():
374-
counter = bounded_activate(child, counter, path)
389+
counter = bounded_activate(child, counter, path, current_state)
375390
path.pop()
376391
return counter
377392

@@ -574,7 +589,7 @@ def build_tree(self, state, store):
574589
def replace_with_inactive(self, terminated_active_ranks):
575590
replaced_terminate_active_ranks = set()
576591

577-
log = logging.getLogger(__name__)
592+
log = logging.getLogger(LogConfig.name)
578593

579594
for terminated_active_rank in sorted(terminated_active_ranks):
580595
terminated_active_node = self.rank_map[terminated_active_rank]
@@ -625,7 +640,7 @@ def replace_with_backfill(self, unhandled_terminations):
625640
key=lambda node: node.state.active_rank,
626641
)
627642

628-
log = logging.getLogger(__name__)
643+
log = logging.getLogger(LogConfig.name)
629644
for backfill_node, terminated_node in itertools.zip_longest(
630645
reversed(largest_active_nodes),
631646
terminated_nodes,
@@ -647,7 +662,7 @@ def replace_with_backfill(self, unhandled_terminations):
647662

648663
def shift_ranks(self, replaced_active, unhandled_terminations):
649664
sorted_replaced_active = sorted(replaced_active)
650-
log = logging.getLogger(__name__)
665+
log = logging.getLogger(LogConfig.name)
651666

652667
for n in self.rank_map.values():
653668
n.state.active_world_size -= len(unhandled_terminations)
@@ -672,7 +687,7 @@ def filter_active_world_size(self):
672687
new_active_world_size = self.world_size_filter(active_world_size)
673688
assert new_active_world_size <= active_world_size
674689

675-
log = logging.getLogger(__name__)
690+
log = logging.getLogger(LogConfig.name)
676691
for leaf in self.tree.iter_leaves():
677692
leaf.state.active_world_size = new_active_world_size
678693
if leaf.state.mode == Mode.ACTIVE and leaf.state.active_rank >= new_active_world_size:
@@ -722,7 +737,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
722737
if self.tree is None:
723738
self.build_tree(state, store)
724739

725-
active_world_size = bounded_activate(self.tree, 0)
740+
active_world_size = bounded_activate(self.tree, 0, None, self.current_state)
726741
for node in self.rank_map.values():
727742
node.state.active_world_size = active_world_size
728743

@@ -738,7 +753,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
738753
rank for rank in terminated_ranks if self.rank_map[rank].state.mode == Mode.ACTIVE
739754
)
740755

741-
log = logging.getLogger(__name__)
756+
log = logging.getLogger(LogConfig.name)
742757
for terminated_rank in terminated_ranks:
743758
# If this rank is being terminated, log it
744759
if self.current_state.initial_rank == self.rank_map[terminated_rank].state.initial_rank:
@@ -808,7 +823,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
808823
terminated_ranks = utils.format_rank_set(terminated_ranks)
809824
raise RankDiscarded(f'{rank=} {terminated_ranks=}')
810825
elif rank >= world_size:
811-
log = logging.getLogger(__name__)
826+
log = logging.getLogger(LogConfig.name)
812827
old_rank = rank
813828
rank = ordered_terminated_ranks[rank - world_size]
814829
log.info(
@@ -869,7 +884,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
869884
old_rank = rank
870885
rank = rank - sum(rank > terminated_rank for terminated_rank in terminated_ranks)
871886
if old_rank != rank:
872-
log = logging.getLogger(__name__)
887+
log = logging.getLogger(LogConfig.name)
873888
log.info(f"[In-process] Rank shifted (rank changed from {old_rank} to {rank})")
874889

875890
state = dataclasses.replace(
@@ -982,7 +997,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
982997

983998
group_count = int(store.get(prefixed_key))
984999
if not self.condition(group_count):
985-
log = logging.getLogger(__name__)
1000+
log = logging.getLogger(LogConfig.name)
9861001
log.info(
9871002
f"[In-process] Rank marked for termination (rank={rank}, group_key={key}, group_count={group_count}) due to failed group condition"
9881003
)

0 commit comments

Comments
 (0)