Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions src/nvidia_resiliency_ext/inprocess/rank_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import warnings
from typing import Callable, Optional, Union

from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig

from . import exception, utils
from .state import Mode, State
from .store import StoreMixin
Expand Down Expand Up @@ -177,7 +179,8 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
active_rank = None
# Log deactivation if transitioning from ACTIVE to INACTIVE
if state.mode == Mode.ACTIVE:
log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)

log.info(
f"[In-process] Rank deactivated (rank={state.rank}) due to max active world size limit ({active_world_size})"
)
Expand Down Expand Up @@ -224,7 +227,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
active_rank = None
# Log deactivation if transitioning from ACTIVE to INACTIVE
if state.mode == Mode.ACTIVE:
log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)
log.info(
f"[In-process] Rank deactivated (rank={state.rank}) due to divisibility requirement (active_world_size={active_world_size}, divisor={divisor})"
)
Expand Down Expand Up @@ -349,7 +352,7 @@ def __repr__(self):
return f'{type(self).__name__}({self.name=})'


def bounded_activate(node, counter, path=None):
def bounded_activate(node, counter, path=None, current_state=None):
if path is None:
path = []

Expand All @@ -361,17 +364,29 @@ def bounded_activate(node, counter, path=None):
for ascendant in path
)
):
# Log activation if this is the current rank
if current_state and current_state.initial_rank == node.state.initial_rank:
log = logging.getLogger(LogConfig.name)
log.info(
f"[In-process] Rank activated (initial_rank={node.state.initial_rank}, active_rank={counter}) in topology tree"
)
node.activate(counter)
counter += 1
for ascendant in path:
ascendant.active_count += 1
else:
# Log deactivation if this is the current rank
if current_state and current_state.initial_rank == node.state.initial_rank:
log = logging.getLogger(LogConfig.name)
log.info(
f"[In-process] Rank deactivated (initial_rank={node.state.initial_rank}) due to max_ranks constraint in topology layer"
)
node.deactivate()

path.append(node)

for child in node.children.values():
counter = bounded_activate(child, counter, path)
counter = bounded_activate(child, counter, path, current_state)
path.pop()
return counter

Expand Down Expand Up @@ -574,7 +589,7 @@ def build_tree(self, state, store):
def replace_with_inactive(self, terminated_active_ranks):
replaced_terminate_active_ranks = set()

log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)

for terminated_active_rank in sorted(terminated_active_ranks):
terminated_active_node = self.rank_map[terminated_active_rank]
Expand Down Expand Up @@ -625,7 +640,7 @@ def replace_with_backfill(self, unhandled_terminations):
key=lambda node: node.state.active_rank,
)

log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)
for backfill_node, terminated_node in itertools.zip_longest(
reversed(largest_active_nodes),
terminated_nodes,
Expand All @@ -647,7 +662,7 @@ def replace_with_backfill(self, unhandled_terminations):

def shift_ranks(self, replaced_active, unhandled_terminations):
sorted_replaced_active = sorted(replaced_active)
log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)

for n in self.rank_map.values():
n.state.active_world_size -= len(unhandled_terminations)
Expand All @@ -672,7 +687,7 @@ def filter_active_world_size(self):
new_active_world_size = self.world_size_filter(active_world_size)
assert new_active_world_size <= active_world_size

log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)
for leaf in self.tree.iter_leaves():
leaf.state.active_world_size = new_active_world_size
if leaf.state.mode == Mode.ACTIVE and leaf.state.active_rank >= new_active_world_size:
Expand Down Expand Up @@ -722,7 +737,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
if self.tree is None:
self.build_tree(state, store)

active_world_size = bounded_activate(self.tree, 0)
active_world_size = bounded_activate(self.tree, 0, None, self.current_state)
for node in self.rank_map.values():
node.state.active_world_size = active_world_size

Expand All @@ -738,7 +753,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
rank for rank in terminated_ranks if self.rank_map[rank].state.mode == Mode.ACTIVE
)

log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)
for terminated_rank in terminated_ranks:
# If this rank is being terminated, log it
if self.current_state.initial_rank == self.rank_map[terminated_rank].state.initial_rank:
Expand Down Expand Up @@ -808,7 +823,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
terminated_ranks = utils.format_rank_set(terminated_ranks)
raise RankDiscarded(f'{rank=} {terminated_ranks=}')
elif rank >= world_size:
log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)
old_rank = rank
rank = ordered_terminated_ranks[rank - world_size]
log.info(
Expand Down Expand Up @@ -869,7 +884,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
old_rank = rank
rank = rank - sum(rank > terminated_rank for terminated_rank in terminated_ranks)
if old_rank != rank:
log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)
log.info(f"[In-process] Rank shifted (rank changed from {old_rank} to {rank})")

state = dataclasses.replace(
Expand Down Expand Up @@ -982,7 +997,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:

group_count = int(store.get(prefixed_key))
if not self.condition(group_count):
log = logging.getLogger(__name__)
log = logging.getLogger(LogConfig.name)
log.info(
f"[In-process] Rank marked for termination (rank={rank}, group_key={key}, group_count={group_count}) due to failed group condition"
)
Expand Down