26
26
import warnings
27
27
from typing import Callable , Optional , Union
28
28
29
+ from nvidia_resiliency_ext .shared_utils .log_manager import LogConfig
30
+
29
31
from . import exception , utils
30
32
from .state import Mode , State
31
33
from .store import StoreMixin
@@ -177,7 +179,8 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
177
179
active_rank = None
178
180
# Log deactivation if transitioning from ACTIVE to INACTIVE
179
181
if state .mode == Mode .ACTIVE :
180
- log = logging .getLogger (__name__ )
182
+ log = logging .getLogger (LogConfig .name )
183
+
181
184
log .info (
182
185
f"[In-process] Rank deactivated (rank={ state .rank } ) due to max active world size limit ({ active_world_size } )"
183
186
)
@@ -224,7 +227,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
224
227
active_rank = None
225
228
# Log deactivation if transitioning from ACTIVE to INACTIVE
226
229
if state .mode == Mode .ACTIVE :
227
- log = logging .getLogger (__name__ )
230
+ log = logging .getLogger (LogConfig . name )
228
231
log .info (
229
232
f"[In-process] Rank deactivated (rank={ state .rank } ) due to divisibility requirement (active_world_size={ active_world_size } , divisor={ divisor } )"
230
233
)
@@ -349,7 +352,7 @@ def __repr__(self):
349
352
return f'{ type (self ).__name__ } ({ self .name = } )'
350
353
351
354
352
- def bounded_activate (node , counter , path = None ):
355
+ def bounded_activate (node , counter , path = None , current_state = None ):
353
356
if path is None :
354
357
path = []
355
358
@@ -361,17 +364,29 @@ def bounded_activate(node, counter, path=None):
361
364
for ascendant in path
362
365
)
363
366
):
367
+ # Log activation if this is the current rank
368
+ if current_state and current_state .initial_rank == node .state .initial_rank :
369
+ log = logging .getLogger (LogConfig .name )
370
+ log .info (
371
+ f"[In-process] Rank activated (initial_rank={ node .state .initial_rank } , active_rank={ counter } ) in topology tree"
372
+ )
364
373
node .activate (counter )
365
374
counter += 1
366
375
for ascendant in path :
367
376
ascendant .active_count += 1
368
377
else :
378
+ # Log deactivation if this is the current rank
379
+ if current_state and current_state .initial_rank == node .state .initial_rank :
380
+ log = logging .getLogger (LogConfig .name )
381
+ log .info (
382
+ f"[In-process] Rank deactivated (initial_rank={ node .state .initial_rank } ) due to max_ranks constraint in topology layer"
383
+ )
369
384
node .deactivate ()
370
385
371
386
path .append (node )
372
387
373
388
for child in node .children .values ():
374
- counter = bounded_activate (child , counter , path )
389
+ counter = bounded_activate (child , counter , path , current_state )
375
390
path .pop ()
376
391
return counter
377
392
@@ -574,7 +589,7 @@ def build_tree(self, state, store):
574
589
def replace_with_inactive (self , terminated_active_ranks ):
575
590
replaced_terminate_active_ranks = set ()
576
591
577
- log = logging .getLogger (__name__ )
592
+ log = logging .getLogger (LogConfig . name )
578
593
579
594
for terminated_active_rank in sorted (terminated_active_ranks ):
580
595
terminated_active_node = self .rank_map [terminated_active_rank ]
@@ -625,7 +640,7 @@ def replace_with_backfill(self, unhandled_terminations):
625
640
key = lambda node : node .state .active_rank ,
626
641
)
627
642
628
- log = logging .getLogger (__name__ )
643
+ log = logging .getLogger (LogConfig . name )
629
644
for backfill_node , terminated_node in itertools .zip_longest (
630
645
reversed (largest_active_nodes ),
631
646
terminated_nodes ,
@@ -647,7 +662,7 @@ def replace_with_backfill(self, unhandled_terminations):
647
662
648
663
def shift_ranks (self , replaced_active , unhandled_terminations ):
649
664
sorted_replaced_active = sorted (replaced_active )
650
- log = logging .getLogger (__name__ )
665
+ log = logging .getLogger (LogConfig . name )
651
666
652
667
for n in self .rank_map .values ():
653
668
n .state .active_world_size -= len (unhandled_terminations )
@@ -672,7 +687,7 @@ def filter_active_world_size(self):
672
687
new_active_world_size = self .world_size_filter (active_world_size )
673
688
assert new_active_world_size <= active_world_size
674
689
675
- log = logging .getLogger (__name__ )
690
+ log = logging .getLogger (LogConfig . name )
676
691
for leaf in self .tree .iter_leaves ():
677
692
leaf .state .active_world_size = new_active_world_size
678
693
if leaf .state .mode == Mode .ACTIVE and leaf .state .active_rank >= new_active_world_size :
@@ -722,7 +737,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
722
737
if self .tree is None :
723
738
self .build_tree (state , store )
724
739
725
- active_world_size = bounded_activate (self .tree , 0 )
740
+ active_world_size = bounded_activate (self .tree , 0 , None , self . current_state )
726
741
for node in self .rank_map .values ():
727
742
node .state .active_world_size = active_world_size
728
743
@@ -738,7 +753,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
738
753
rank for rank in terminated_ranks if self .rank_map [rank ].state .mode == Mode .ACTIVE
739
754
)
740
755
741
- log = logging .getLogger (__name__ )
756
+ log = logging .getLogger (LogConfig . name )
742
757
for terminated_rank in terminated_ranks :
743
758
# If this rank is being terminated, log it
744
759
if self .current_state .initial_rank == self .rank_map [terminated_rank ].state .initial_rank :
@@ -808,7 +823,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
808
823
terminated_ranks = utils .format_rank_set (terminated_ranks )
809
824
raise RankDiscarded (f'{ rank = } { terminated_ranks = } ' )
810
825
elif rank >= world_size :
811
- log = logging .getLogger (__name__ )
826
+ log = logging .getLogger (LogConfig . name )
812
827
old_rank = rank
813
828
rank = ordered_terminated_ranks [rank - world_size ]
814
829
log .info (
@@ -869,7 +884,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
869
884
old_rank = rank
870
885
rank = rank - sum (rank > terminated_rank for terminated_rank in terminated_ranks )
871
886
if old_rank != rank :
872
- log = logging .getLogger (__name__ )
887
+ log = logging .getLogger (LogConfig . name )
873
888
log .info (f"[In-process] Rank shifted (rank changed from { old_rank } to { rank } )" )
874
889
875
890
state = dataclasses .replace (
@@ -982,7 +997,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
982
997
983
998
group_count = int (store .get (prefixed_key ))
984
999
if not self .condition (group_count ):
985
- log = logging .getLogger (__name__ )
1000
+ log = logging .getLogger (LogConfig . name )
986
1001
log .info (
987
1002
f"[In-process] Rank marked for termination (rank={ rank } , group_key={ key } , group_count={ group_count } ) due to failed group condition"
988
1003
)
0 commit comments